[spark] branch master updated: [SPARK-42002][CONNECT][PYTHON] Implement DataFrameWriterV2
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 537c04ff7ae [SPARK-42002][CONNECT][PYTHON] Implement DataFrameWriterV2
537c04ff7ae is described below
commit 537c04ff7ae8eddb0d52684a01ff7fa2ace91103
Author: Sandeep Singh <sa...@techaddict.me>
AuthorDate: Mon Feb 6 18:39:55 2023 +0900
[SPARK-42002][CONNECT][PYTHON] Implement DataFrameWriterV2
### What changes were proposed in this pull request?
Implements DataFrameWriterV2 (DataFrame.writeTo) for the Spark Connect Python client, including the new WriteOperationV2 protobuf command and its server-side handling.
### Why are the changes needed?
Feature parity with PySpark: DataFrame.writeTo and DataFrameWriterV2 already exist in PySpark but were missing from the Spark Connect client.
### Does this PR introduce _any_ user-facing change?
Yes. DataFrame.writeTo is now implemented for Spark Connect and returns a DataFrameWriterV2.
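For illustration, a minimal sketch of the new user-facing API (assuming spark is a Spark Connect session and testcat.table_name is a placeholder table name):

    df = spark.range(3)
    df.writeTo("testcat.table_name").using("parquet").createOrReplace()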
### How was this patch tested?
New unit tests in SparkConnectProtoSuite and test_connect_basic.py.
Closes #39614 from techaddict/SPARK-42002.
Authored-by: Sandeep Singh <sa...@techaddict.me>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
.../src/main/protobuf/spark/connect/commands.proto | 42 ++++-
.../org/apache/spark/sql/connect/dsl/package.scala | 38 +++++
.../sql/connect/planner/SparkConnectPlanner.scala | 70 ++++++++
.../connect/planner/SparkConnectProtoSuite.scala | 188 ++++++++++++++++++++-
python/pyspark/sql/connect/dataframe.py | 9 +-
python/pyspark/sql/connect/plan.py | 65 +++++++
python/pyspark/sql/connect/proto/commands_pb2.py | 84 +++++++--
python/pyspark/sql/connect/proto/commands_pb2.pyi | 165 +++++++++++++++++-
python/pyspark/sql/connect/readwriter.py | 82 ++++++++-
python/pyspark/sql/dataframe.py | 3 +
python/pyspark/sql/readwriter.py | 3 +
.../sql/tests/connect/test_connect_basic.py | 24 ++-
12 files changed, 746 insertions(+), 27 deletions(-)
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
index ffacfc008a0..da81da844e7 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
@@ -18,6 +18,7 @@
syntax = 'proto3';
import "google/protobuf/any.proto";
+import "spark/connect/expressions.proto";
import "spark/connect/relations.proto";
import "spark/connect/types.proto";
@@ -33,6 +34,7 @@ message Command {
CreateScalarFunction create_function = 1;
WriteOperation write_operation = 2;
CreateDataFrameViewCommand create_dataframe_view = 3;
+ WriteOperationV2 write_operation_v2 = 4;
// This field is used to mark extensions to the protocol. When plugins generate arbitrary
// Commands they can add them here. During the planning the correct resolution is done.
@@ -140,4 +142,42 @@ message WriteOperation {
SAVE_MODE_ERROR_IF_EXISTS = 3;
SAVE_MODE_IGNORE = 4;
}
-}
\ No newline at end of file
+}
+
+// As writes are not directly handled during analysis and planning, they are modeled as commands.
+message WriteOperationV2 {
+ // (Required) The output of the `input` relation will be persisted according to the options.
+ Relation input = 1;
+
+ // (Required) The destination of the write operation must be a table name.
+ string table_name = 2;
+
+ // A provider for the underlying output data source. Spark's default catalog supports
+ // "parquet", "json", etc.
+ string provider = 3;
+
+ // (Optional) List of columns for partitioning the output table created by `create`,
+ // `createOrReplace`, or `replace`.
+ repeated Expression partitioning_columns = 4;
+
+ // (Optional) A list of configuration options.
+ map<string, string> options = 5;
+
+ // (Optional) A list of table properties.
+ map<string, string> table_properties = 6;
+
+ Mode mode = 7;
+
+ enum Mode {
+ MODE_UNSPECIFIED = 0;
+ MODE_CREATE = 1;
+ MODE_OVERWRITE = 2;
+ MODE_OVERWRITE_PARTITIONS = 3;
+ MODE_APPEND = 4;
+ MODE_REPLACE = 5;
+ MODE_CREATE_OR_REPLACE = 6;
+ }
+
+ // (Optional) A condition for the overwrite save mode.
+ Expression overwrite_condition = 8;
+}
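As a rough sketch of how a client populates this message directly with the generated Python bindings added later in this patch (the required input relation is omitted here for brevity):

    from pyspark.sql.connect.proto import commands_pb2 as pb2

    cmd = pb2.Command()
    cmd.write_operation_v2.table_name = "testcat.table_name"
    cmd.write_operation_v2.provider = "parquet"
    cmd.write_operation_v2.options["compression"] = "snappy"  # map<string, string>
    cmd.write_operation_v2.mode = pb2.WriteOperationV2.Mode.MODE_CREATE

In practice clients go through the DataFrameWriterV2 API below rather than building the proto by hand.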
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 1c98162c76e..88531286e24 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -242,6 +242,44 @@ package object dsl {
.setInput(logicalPlan))
.build()
}
+
+ def writeV2(
+ tableName: Option[String] = None,
+ provider: Option[String] = None,
+ options: Map[String, String] = Map.empty,
+ tableProperties: Map[String, String] = Map.empty,
+ partitionByCols: Seq[Expression] = Seq.empty,
+ mode: Option[String] = None,
+ overwriteCondition: Option[Expression] = None): Command = {
+ val writeOp = WriteOperationV2.newBuilder()
+ writeOp.setInput(logicalPlan)
+ tableName.foreach(writeOp.setTableName)
+ provider.foreach(writeOp.setProvider)
+ partitionByCols.foreach(writeOp.addPartitioningColumns)
+ options.foreach { case (k, v) =>
+ writeOp.putOptions(k, v)
+ }
+ tableProperties.foreach { case (k, v) =>
+ writeOp.putTableProperties(k, v)
+ }
+ mode.foreach { m =>
+ if (m == "MODE_CREATE") {
+ writeOp.setMode(WriteOperationV2.Mode.MODE_CREATE)
+ } else if (m == "MODE_OVERWRITE") {
+ writeOp.setMode(WriteOperationV2.Mode.MODE_OVERWRITE)
+ overwriteCondition.foreach(writeOp.setOverwriteCondition)
+ } else if (m == "MODE_OVERWRITE_PARTITIONS") {
+ writeOp.setMode(WriteOperationV2.Mode.MODE_OVERWRITE_PARTITIONS)
+ } else if (m == "MODE_APPEND") {
+ writeOp.setMode(WriteOperationV2.Mode.MODE_APPEND)
+ } else if (m == "MODE_REPLACE") {
+ writeOp.setMode(WriteOperationV2.Mode.MODE_REPLACE)
+ } else if (m == "MODE_CREATE_OR_REPLACE") {
+ writeOp.setMode(WriteOperationV2.Mode.MODE_CREATE_OR_REPLACE)
+ }
+ }
+ Command.newBuilder().setWriteOperationV2(writeOp.build()).build()
+ }
}
}
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 51d115ef1ca..08df2274840 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -45,6 +45,7 @@ import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.execution.command.CreateViewCommand
import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
+import org.apache.spark.sql.functions.{col, expr}
import org.apache.spark.sql.internal.CatalogImpl
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -1408,6 +1409,8 @@ class SparkConnectPlanner(val session: SparkSession) {
handleWriteOperation(command.getWriteOperation)
case proto.Command.CommandTypeCase.CREATE_DATAFRAME_VIEW =>
handleCreateViewCommand(command.getCreateDataframeView)
+ case proto.Command.CommandTypeCase.WRITE_OPERATION_V2 =>
+ handleWriteOperationV2(command.getWriteOperationV2)
case proto.Command.CommandTypeCase.EXTENSION =>
handleCommandPlugin(command.getExtension)
case _ => throw new UnsupportedOperationException(s"$command not supported.")
@@ -1546,6 +1549,73 @@ class SparkConnectPlanner(val session: SparkSession) {
}
}
+ /**
+ * Transforms the V2 write operation and executes it.
+ *
+ * The input write operation contains a reference to the input plan and transforms it to the
+ * corresponding logical plan. Afterwards, creates the DataFrameWriterV2 and translates the
+ * parameters of the WriteOperationV2 into the corresponding method calls.
+ *
+ * @param writeOperation the proto representation of the V2 write operation
+ */
+ def handleWriteOperationV2(writeOperation: proto.WriteOperationV2): Unit = {
+ // Transform the input plan into the logical plan.
+ val planner = new SparkConnectPlanner(session)
+ val plan = planner.transformRelation(writeOperation.getInput)
+ // And create a Dataset from the plan.
+ val dataset = Dataset.ofRows(session, logicalPlan = plan)
+
+ val w = dataset.writeTo(table = writeOperation.getTableName)
+
+ if (writeOperation.getOptionsCount > 0) {
+ writeOperation.getOptionsMap.asScala.foreach { case (key, value) => w.option(key, value) }
+ }
+
+ if (writeOperation.getTablePropertiesCount > 0) {
+ writeOperation.getTablePropertiesMap.asScala.foreach { case (key, value) =>
+ w.tableProperty(key, value)
+ }
+ }
+
+ if (writeOperation.getPartitioningColumnsCount > 0) {
+ val names = writeOperation.getPartitioningColumnsList.asScala
+ .map(transformExpression)
+ .map(Column(_))
+ .toSeq
+ w.partitionedBy(names.head, names.tail.toSeq: _*)
+ }
+
+ writeOperation.getMode match {
+ case proto.WriteOperationV2.Mode.MODE_CREATE =>
+ if (writeOperation.getProvider.nonEmpty) {
+ w.using(writeOperation.getProvider).create()
+ } else {
+ w.create()
+ }
+ case proto.WriteOperationV2.Mode.MODE_OVERWRITE =>
+ w.overwrite(Column(transformExpression(writeOperation.getOverwriteCondition)))
+ case proto.WriteOperationV2.Mode.MODE_OVERWRITE_PARTITIONS =>
+ w.overwritePartitions()
+ case proto.WriteOperationV2.Mode.MODE_APPEND =>
+ w.append()
+ case proto.WriteOperationV2.Mode.MODE_REPLACE =>
+ if (writeOperation.getProvider.nonEmpty) {
+ w.using(writeOperation.getProvider).replace()
+ } else {
+ w.replace()
+ }
+ case proto.WriteOperationV2.Mode.MODE_CREATE_OR_REPLACE =>
+ if (writeOperation.getProvider.nonEmpty) {
+ w.using(writeOperation.getProvider).createOrReplace()
+ } else {
+ w.createOrReplace()
+ }
+ case _ =>
+ throw new UnsupportedOperationException(
+ "WriteOperationV2:ModeValue not supported ${writeOperation.getModeValue}")
+ }
+ }
+
private val emptyLocalRelation = LocalRelation(
output = AttributeReference("value", StringType, false)() :: Nil,
data = Seq.empty)
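To make the mode dispatch above concrete: each client-side writer method maps to one proto mode, which this handler translates back into the matching DataFrameWriterV2 call on the server. A hedged sketch of the client calls (df is a Connect DataFrame; each line is an independent example, not a sequence):

    from pyspark.sql.connect.functions import lit

    df.writeTo("t").create()               # MODE_CREATE               -> w.create()
    df.writeTo("t").append()               # MODE_APPEND               -> w.append()
    df.writeTo("t").replace()              # MODE_REPLACE              -> w.replace()
    df.writeTo("t").createOrReplace()      # MODE_CREATE_OR_REPLACE    -> w.createOrReplace()
    df.writeTo("t").overwrite(lit(True))   # MODE_OVERWRITE            -> w.overwrite(cond)
    df.writeTo("t").overwritePartitions()  # MODE_OVERWRITE_PARTITIONS -> w.overwritePartitions()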
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 4c4a070bb4f..87117801eb7 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -18,14 +18,17 @@ package org.apache.spark.sql.connect.planner
import java.nio.file.{Files, Paths}
+import scala.collection.JavaConverters._
+
import com.google.protobuf.ByteString
import org.apache.spark.SparkClassNotFoundException
import org.apache.spark.connect.proto
+import org.apache.spark.connect.proto.Expression
import org.apache.spark.connect.proto.Join.JoinType
import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Row, SaveMode}
import org.apache.spark.sql.catalyst.analysis
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericInternalRow, UnsafeProjection}
import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -34,9 +37,13 @@ import org.apache.spark.sql.connect.dsl.commands._
import org.apache.spark.sql.connect.dsl.expressions._
import org.apache.spark.sql.connect.dsl.plans._
import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.toConnectProtoValue
+import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog, TableCatalog}
+import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.CatalogHelper
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, Metadata, ShortType, StringType, StructField, StructType}
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.Utils
/**
* This suite is based on connect DSL and test that given same dataframe operations, whether
@@ -647,6 +654,185 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
}
}
+ test("WriteTo with create") {
+ withTable("testcat.table_name") {
+ spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)
+
+ val rows = Seq(
+ new GenericInternalRow(Array(1L, UTF8String.fromString("a"))),
+ new GenericInternalRow(Array(2L, UTF8String.fromString("b"))),
+ new GenericInternalRow(Array(3L, UTF8String.fromString("c"))))
+
+ val schema = StructType(Array(StructField("id", LongType), StructField("data", StringType)))
+ val inputRows = rows.map { row =>
+ val proj = UnsafeProjection.create(schema)
+ proj(row).copy()
+ }
+
+ val localRelationV2 = createLocalRelationProto(schema.toAttributes, inputRows)
+
+ val cmd = localRelationV2.writeV2(
+ tableName = Some("testcat.table_name"),
+ mode = Some("MODE_CREATE"))
+ transform(cmd)
+
+ val outputRows = spark.table("testcat.table_name").collect()
+ assert(outputRows.length == 3)
+ }
+ }
+
+ test("WriteTo with create and using") {
+ val defaultOwnership = Map(TableCatalog.PROP_OWNER -> Utils.getCurrentUserName())
+ withTable("testcat.table_name") {
+ spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)
+
+ val rows = Seq(
+ new GenericInternalRow(Array(1L, UTF8String.fromString("a"))),
+ new GenericInternalRow(Array(2L, UTF8String.fromString("b"))),
+ new GenericInternalRow(Array(3L, UTF8String.fromString("c"))))
+
+ val schema = StructType(Array(StructField("id", LongType), StructField("data", StringType)))
+ val inputRows = rows.map { row =>
+ val proj = UnsafeProjection.create(schema)
+ proj(row).copy()
+ }
+
+ val localRelationV2 = createLocalRelationProto(schema.toAttributes, inputRows)
+
+ val cmd = localRelationV2.writeV2(
+ tableName = Some("testcat.table_name"),
+ provider = Some("foo"),
+ mode = Some("MODE_CREATE"))
+ transform(cmd)
+
+ val outputRows = spark.table("testcat.table_name").collect()
+ assert(outputRows.length == 3)
+ val table = spark.sessionState.catalogManager
+ .catalog("testcat")
+ .asTableCatalog
+ .loadTable(Identifier.of(Array(), "table_name"))
+ assert(table.name === "testcat.table_name")
+ assert(table.schema === new StructType().add("id", LongType).add("data", StringType))
+ assert(table.partitioning.isEmpty)
+ assert(table.properties === (Map("provider" -> "foo") ++ defaultOwnership).asJava)
+ }
+ }
+
+ test("WriteTo with append") {
+ withTable("testcat.table_name") {
+ spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)
+
+ val rows = Seq(
+ new GenericInternalRow(Array(1L, UTF8String.fromString("a"))),
+ new GenericInternalRow(Array(2L, UTF8String.fromString("b"))),
+ new GenericInternalRow(Array(3L, UTF8String.fromString("c"))))
+
+ val schema = StructType(Array(StructField("id", LongType), StructField("data", StringType)))
+ val inputRows = rows.map { row =>
+ val proj = UnsafeProjection.create(schema)
+ proj(row).copy()
+ }
+
+ val localRelationV2 = createLocalRelationProto(schema.toAttributes, inputRows)
+
+ spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo")
+
+ assert(spark.table("testcat.table_name").collect().isEmpty)
+
+ val cmd = localRelationV2.writeV2(
+ tableName = Some("testcat.table_name"),
+ mode = Some("MODE_APPEND"))
+ transform(cmd)
+
+ val outputRows = spark.table("testcat.table_name").collect()
+ assert(outputRows.length == 3)
+ }
+ }
+
+ test("WriteTo with overwrite") {
+ withTable("testcat.table_name") {
+ spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)
+
+ val rows1 = (1L to 3L).map { i =>
+ new GenericInternalRow(Array(i, UTF8String.fromString("" + (i - 1 + 'a').toChar)))
+ }
+ val rows2 = (4L to 7L).map { i =>
+ new GenericInternalRow(Array(i, UTF8String.fromString("" + (i - 1 + 'a').toChar)))
+ }
+
+ val schema = StructType(Array(StructField("id", LongType), StructField("data", StringType)))
+ val inputRows1 = rows1.map { row =>
+ val proj = UnsafeProjection.create(schema)
+ proj(row).copy()
+ }
+ val inputRows2 = rows2.map { row =>
+ val proj = UnsafeProjection.create(schema)
+ proj(row).copy()
+ }
+
+ val localRelation1V2 = createLocalRelationProto(schema.toAttributes, inputRows1)
+ val localRelation2V2 = createLocalRelationProto(schema.toAttributes, inputRows2)
+
+ spark.sql(
+ "CREATE TABLE testcat.table_name (id bigint, data string) USING foo PARTITIONED BY (id)")
+
+ assert(spark.table("testcat.table_name").collect().isEmpty)
+
+ val cmd1 = localRelation1V2.writeV2(
+ tableName = Some("testcat.table_name"),
+ mode = Some("MODE_APPEND"))
+ transform(cmd1)
+
+ val outputRows1 = spark.table("testcat.table_name").collect()
+ assert(outputRows1.length == 3)
+
+ val overwriteCondition = Expression
+ .newBuilder()
+ .setLiteral(Expression.Literal.newBuilder().setBoolean(true))
+ .build()
+
+ val cmd2 = localRelation2V2.writeV2(
+ tableName = Some("testcat.table_name"),
+ mode = Some("MODE_OVERWRITE"),
+ overwriteCondition = Some(overwriteCondition))
+ transform(cmd2)
+
+ val outputRows2 = spark.table("testcat.table_name").collect()
+ assert(outputRows2.length == 4)
+ }
+ }
+
+ test("WriteTo with overwritePartitions") {
+ withTable("testcat.table_name") {
+ spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)
+
+ val rows = (4L to 7L).map { i =>
+ new GenericInternalRow(Array(i, UTF8String.fromString("" + (i - 1 + 'a').toChar)))
+ }
+
+ val schema = StructType(Array(StructField("id", LongType), StructField("data", StringType)))
+ val inputRows = rows.map { row =>
+ val proj = UnsafeProjection.create(schema)
+ proj(row).copy()
+ }
+
+ val localRelationV2 = createLocalRelationProto(schema.toAttributes, inputRows)
+
+ spark.sql(
+ "CREATE TABLE testcat.table_name (id bigint, data string) USING foo PARTITIONED BY (id)")
+
+ assert(spark.table("testcat.table_name").collect().isEmpty)
+
+ val cmd = localRelationV2.writeV2(
+ tableName = Some("testcat.table_name"),
+ mode = Some("MODE_OVERWRITE_PARTITIONS"))
+ transform(cmd)
+
+ val outputRows = spark.table("testcat.table_name").collect()
+ assert(outputRows.length == 4)
+ }
+ }
+
test("Test CreateView") {
withView("view1", "view2", "view3", "view4") {
transform(localRelation.createView("view1", global = true, replace = true))
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index d61938e6108..9b0c911c10f 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -52,7 +52,7 @@ from pyspark.sql.dataframe import (
from pyspark.errors import PySparkTypeError
import pyspark.sql.connect.plan as plan
from pyspark.sql.connect.group import GroupedData
-from pyspark.sql.connect.readwriter import DataFrameWriter
+from pyspark.sql.connect.readwriter import DataFrameWriter, DataFrameWriterV2
from pyspark.sql.connect.column import Column
from pyspark.sql.connect.expressions import UnresolvedRegex
from pyspark.sql.connect.functions import (
@@ -1551,8 +1551,11 @@ class DataFrame:
def sameSemantics(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError("sameSemantics() is not implemented.")
- def writeTo(self, *args: Any, **kwargs: Any) -> None:
- raise NotImplementedError("writeTo() is not implemented.")
+ def writeTo(self, table: str) -> "DataFrameWriterV2":
+ assert self._plan is not None
+ return DataFrameWriterV2(self._plan, self._session, table)
+
+ writeTo.__doc__ = PySparkDataFrame.writeTo.__doc__
# SparkConnect specific API
def offset(self, n: int) -> "DataFrame":
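A short usage sketch of the new method (testcat.table_name is a placeholder; note that writeTo only builds a DataFrameWriterV2, and nothing is sent to the server until a terminal method such as create() runs):

    writer = df.writeTo("testcat.table_name")    # returns DataFrameWriterV2
    writer.using("parquet").tableProperty("owner", "etl").create()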
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index d6b15c66c14..0945adf6d20 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -1416,6 +1416,71 @@ class WriteOperation(LogicalPlan):
pass
+class WriteOperationV2(LogicalPlan):
+ def __init__(self, child: "LogicalPlan", table_name: str) -> None:
+ super(WriteOperationV2, self).__init__(child)
+ self.table_name: Optional[str] = table_name
+ self.provider: Optional[str] = None
+ self.partitioning_columns: List["ColumnOrName"] = []
+ self.options: dict[str, Optional[str]] = {}
+ self.table_properties: dict[str, Optional[str]] = {}
+ self.mode: Optional[str] = None
+ self.overwrite_condition: Optional["ColumnOrName"] = None
+
+ def col_to_expr(self, col: "ColumnOrName", session: "SparkConnectClient") -> proto.Expression:
+ if isinstance(col, Column):
+ return col.to_plan(session)
+ else:
+ return self.unresolved_attr(col)
+
+ def command(self, session: "SparkConnectClient") -> proto.Command:
+ assert self._child is not None
+ plan = proto.Command()
+ plan.write_operation_v2.input.CopyFrom(self._child.plan(session))
+ if self.table_name is not None:
+ plan.write_operation_v2.table_name = self.table_name
+ if self.provider is not None:
+ plan.write_operation_v2.provider = self.provider
+
+ plan.write_operation_v2.partitioning_columns.extend(
+ [self.col_to_expr(x, session) for x in self.partitioning_columns]
+ )
+
+ for k in self.options:
+ if self.options[k] is None:
+ plan.write_operation_v2.options.pop(k, None)
+ else:
+ plan.write_operation_v2.options[k] = cast(str, self.options[k])
+
+ for k in self.table_properties:
+ if self.table_properties[k] is None:
+ plan.write_operation_v2.table_properties.pop(k, None)
+ else:
+ plan.write_operation_v2.table_properties[k] = cast(str, self.table_properties[k])
+
+ if self.mode is not None:
+ wm = self.mode.lower()
+ if wm == "create":
+ plan.write_operation_v2.mode = proto.WriteOperationV2.Mode.MODE_CREATE
+ elif wm == "overwrite":
+ plan.write_operation_v2.mode = proto.WriteOperationV2.Mode.MODE_OVERWRITE
+ elif wm == "overwrite_partition":
+ plan.write_operation_v2.mode = proto.WriteOperationV2.Mode.MODE_OVERWRITE_PARTITIONS
+ elif wm == "append":
+ plan.write_operation_v2.mode = proto.WriteOperationV2.Mode.MODE_APPEND
+ elif wm == "replace":
+ plan.write_operation_v2.mode = proto.WriteOperationV2.Mode.MODE_REPLACE
+ if self.overwrite_condition is not None:
+ plan.write_operation_v2.overwrite_condition.CopyFrom(
+ self.col_to_expr(self.overwrite_condition, session)
+ )
+ elif wm == "create_or_replace":
+ plan.write_operation_v2.mode = proto.WriteOperationV2.Mode.MODE_CREATE_OR_REPLACE
+ else:
+ raise ValueError(f"Unknown Mode value for DataFrame: {self.mode}")
+ return plan
+
+
# Catalog API (internal-only)
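For orientation, a rough sketch of how the writer drives this plan node (child_plan and client are hypothetical stand-ins for a child LogicalPlan and a SparkConnectClient):

    wo = WriteOperationV2(child_plan, "testcat.table_name")
    wo.provider = "parquet"
    wo.mode = "create"
    command = wo.command(client)  # proto.Command with write_operation_v2 populated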
diff --git a/python/pyspark/sql/connect/proto/commands_pb2.py b/python/pyspark/sql/connect/proto/commands_pb2.py
index bc52ce6e763..482e7dd4fcc 100644
--- a/python/pyspark/sql/connect/proto/commands_pb2.py
+++ b/python/pyspark/sql/connect/proto/commands_pb2.py
@@ -30,12 +30,13 @@ _sym_db = _symbol_database.Default()
from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2
+from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_expressions__pb2
from pyspark.sql.connect.proto import relations_pb2 as spark_dot_connect_dot_relations__pb2
from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xcb\x02\n\x07\x43ommand\x12N\n\x0f\x63reate_function\x18\x01 \x01(\x0b\x32#.spark.connect.CreateScalarFunctionH\x00R\x0e\x63reateFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x18\x03 \x01(\x0b\x32).spark.connect.CreateDataFra [...]
+ b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\x9c\x03\n\x07\x43ommand\x12N\n\x0f\x63reate_function\x18\x01 \x01(\x0b\x32#.spark.connect.CreateScalarFunctionH\x00R\x0e\x63reateFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x18\x03 \x0 [...]
)
@@ -45,10 +46,16 @@ _CREATEDATAFRAMEVIEWCOMMAND = DESCRIPTOR.message_types_by_name["CreateDataFrameV
_WRITEOPERATION = DESCRIPTOR.message_types_by_name["WriteOperation"]
_WRITEOPERATION_OPTIONSENTRY = _WRITEOPERATION.nested_types_by_name["OptionsEntry"]
_WRITEOPERATION_BUCKETBY = _WRITEOPERATION.nested_types_by_name["BucketBy"]
+_WRITEOPERATIONV2 = DESCRIPTOR.message_types_by_name["WriteOperationV2"]
+_WRITEOPERATIONV2_OPTIONSENTRY = _WRITEOPERATIONV2.nested_types_by_name["OptionsEntry"]
+_WRITEOPERATIONV2_TABLEPROPERTIESENTRY = _WRITEOPERATIONV2.nested_types_by_name[
+ "TablePropertiesEntry"
+]
_CREATESCALARFUNCTION_FUNCTIONLANGUAGE = _CREATESCALARFUNCTION.enum_types_by_name[
"FunctionLanguage"
]
_WRITEOPERATION_SAVEMODE = _WRITEOPERATION.enum_types_by_name["SaveMode"]
+_WRITEOPERATIONV2_MODE = _WRITEOPERATIONV2.enum_types_by_name["Mode"]
Command = _reflection.GeneratedProtocolMessageType(
"Command",
(_message.Message,),
@@ -113,26 +120,69 @@ _sym_db.RegisterMessage(WriteOperation)
_sym_db.RegisterMessage(WriteOperation.OptionsEntry)
_sym_db.RegisterMessage(WriteOperation.BucketBy)
+WriteOperationV2 = _reflection.GeneratedProtocolMessageType(
+ "WriteOperationV2",
+ (_message.Message,),
+ {
+ "OptionsEntry": _reflection.GeneratedProtocolMessageType(
+ "OptionsEntry",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _WRITEOPERATIONV2_OPTIONSENTRY,
+ "__module__": "spark.connect.commands_pb2"
+ # @@protoc_insertion_point(class_scope:spark.connect.WriteOperationV2.OptionsEntry)
+ },
+ ),
+ "TablePropertiesEntry": _reflection.GeneratedProtocolMessageType(
+ "TablePropertiesEntry",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _WRITEOPERATIONV2_TABLEPROPERTIESENTRY,
+ "__module__": "spark.connect.commands_pb2"
+ # @@protoc_insertion_point(class_scope:spark.connect.WriteOperationV2.TablePropertiesEntry)
+ },
+ ),
+ "DESCRIPTOR": _WRITEOPERATIONV2,
+ "__module__": "spark.connect.commands_pb2"
+ # @@protoc_insertion_point(class_scope:spark.connect.WriteOperationV2)
+ },
+)
+_sym_db.RegisterMessage(WriteOperationV2)
+_sym_db.RegisterMessage(WriteOperationV2.OptionsEntry)
+_sym_db.RegisterMessage(WriteOperationV2.TablePropertiesEntry)
+
if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._options = None
DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
_WRITEOPERATION_OPTIONSENTRY._options = None
_WRITEOPERATION_OPTIONSENTRY._serialized_options = b"8\001"
- _COMMAND._serialized_start = 133
- _COMMAND._serialized_end = 464
- _CREATESCALARFUNCTION._serialized_start = 467
- _CREATESCALARFUNCTION._serialized_end = 1002
- _CREATESCALARFUNCTION_FUNCTIONLANGUAGE._serialized_start = 840
- _CREATESCALARFUNCTION_FUNCTIONLANGUAGE._serialized_end = 979
- _CREATEDATAFRAMEVIEWCOMMAND._serialized_start = 1005
- _CREATEDATAFRAMEVIEWCOMMAND._serialized_end = 1155
- _WRITEOPERATION._serialized_start = 1158
- _WRITEOPERATION._serialized_end = 1900
- _WRITEOPERATION_OPTIONSENTRY._serialized_start = 1596
- _WRITEOPERATION_OPTIONSENTRY._serialized_end = 1654
- _WRITEOPERATION_BUCKETBY._serialized_start = 1656
- _WRITEOPERATION_BUCKETBY._serialized_end = 1747
- _WRITEOPERATION_SAVEMODE._serialized_start = 1750
- _WRITEOPERATION_SAVEMODE._serialized_end = 1887
+ _WRITEOPERATIONV2_OPTIONSENTRY._options = None
+ _WRITEOPERATIONV2_OPTIONSENTRY._serialized_options = b"8\001"
+ _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._options = None
+ _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_options = b"8\001"
+ _COMMAND._serialized_start = 166
+ _COMMAND._serialized_end = 578
+ _CREATESCALARFUNCTION._serialized_start = 581
+ _CREATESCALARFUNCTION._serialized_end = 1116
+ _CREATESCALARFUNCTION_FUNCTIONLANGUAGE._serialized_start = 954
+ _CREATESCALARFUNCTION_FUNCTIONLANGUAGE._serialized_end = 1093
+ _CREATEDATAFRAMEVIEWCOMMAND._serialized_start = 1119
+ _CREATEDATAFRAMEVIEWCOMMAND._serialized_end = 1269
+ _WRITEOPERATION._serialized_start = 1272
+ _WRITEOPERATION._serialized_end = 2014
+ _WRITEOPERATION_OPTIONSENTRY._serialized_start = 1710
+ _WRITEOPERATION_OPTIONSENTRY._serialized_end = 1768
+ _WRITEOPERATION_BUCKETBY._serialized_start = 1770
+ _WRITEOPERATION_BUCKETBY._serialized_end = 1861
+ _WRITEOPERATION_SAVEMODE._serialized_start = 1864
+ _WRITEOPERATION_SAVEMODE._serialized_end = 2001
+ _WRITEOPERATIONV2._serialized_start = 2017
+ _WRITEOPERATIONV2._serialized_end = 2812
+ _WRITEOPERATIONV2_OPTIONSENTRY._serialized_start = 1710
+ _WRITEOPERATIONV2_OPTIONSENTRY._serialized_end = 1768
+ _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_start = 2584
+ _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_end = 2650
+ _WRITEOPERATIONV2_MODE._serialized_start = 2653
+ _WRITEOPERATIONV2_MODE._serialized_end = 2812
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/commands_pb2.pyi b/python/pyspark/sql/connect/proto/commands_pb2.pyi
index 2cebbf47c23..2447be6b467 100644
--- a/python/pyspark/sql/connect/proto/commands_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/commands_pb2.pyi
@@ -40,6 +40,7 @@ import google.protobuf.descriptor
import google.protobuf.internal.containers
import google.protobuf.internal.enum_type_wrapper
import google.protobuf.message
+import pyspark.sql.connect.proto.expressions_pb2
import pyspark.sql.connect.proto.relations_pb2
import pyspark.sql.connect.proto.types_pb2
import sys
@@ -62,6 +63,7 @@ class Command(google.protobuf.message.Message):
CREATE_FUNCTION_FIELD_NUMBER: builtins.int
WRITE_OPERATION_FIELD_NUMBER: builtins.int
CREATE_DATAFRAME_VIEW_FIELD_NUMBER: builtins.int
+ WRITE_OPERATION_V2_FIELD_NUMBER: builtins.int
EXTENSION_FIELD_NUMBER: builtins.int
@property
def create_function(self) -> global___CreateScalarFunction: ...
@@ -70,6 +72,8 @@ class Command(google.protobuf.message.Message):
@property
def create_dataframe_view(self) -> global___CreateDataFrameViewCommand: ...
@property
+ def write_operation_v2(self) -> global___WriteOperationV2: ...
+ @property
def extension(self) -> google.protobuf.any_pb2.Any:
"""This field is used to mark extensions to the protocol. When plugins generate arbitrary
Commands they can add them here. During the planning the correct resolution is done.
@@ -80,6 +84,7 @@ class Command(google.protobuf.message.Message):
create_function: global___CreateScalarFunction | None = ...,
write_operation: global___WriteOperation | None = ...,
create_dataframe_view: global___CreateDataFrameViewCommand | None = ...,
+ write_operation_v2: global___WriteOperationV2 | None = ...,
extension: google.protobuf.any_pb2.Any | None = ...,
) -> None: ...
def HasField(
@@ -95,6 +100,8 @@ class Command(google.protobuf.message.Message):
b"extension",
"write_operation",
b"write_operation",
+ "write_operation_v2",
+ b"write_operation_v2",
],
) -> builtins.bool: ...
def ClearField(
@@ -110,12 +117,18 @@ class Command(google.protobuf.message.Message):
b"extension",
"write_operation",
b"write_operation",
+ "write_operation_v2",
+ b"write_operation_v2",
],
) -> None: ...
def WhichOneof(
self, oneof_group: typing_extensions.Literal["command_type", b"command_type"]
) -> typing_extensions.Literal[
- "create_function", "write_operation", "create_dataframe_view", "extension"
+ "create_function",
+ "write_operation",
+ "create_dataframe_view",
+ "write_operation_v2",
+ "extension",
] | None: ...
global___Command = Command
@@ -441,3 +454,153 @@ class WriteOperation(google.protobuf.message.Message):
) -> typing_extensions.Literal["path", "table_name"] | None: ...
global___WriteOperation = WriteOperation
+
+class WriteOperationV2(google.protobuf.message.Message):
+ """As writes are not directly handled during analysis and planning, they are modeled as commands."""
+
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ class _Mode:
+ ValueType = typing.NewType("ValueType", builtins.int)
+ V: typing_extensions.TypeAlias = ValueType
+
+ class _ModeEnumTypeWrapper(
+ google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[
+ WriteOperationV2._Mode.ValueType
+ ],
+ builtins.type,
+ ): # noqa: F821
+ DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
+ MODE_UNSPECIFIED: WriteOperationV2._Mode.ValueType # 0
+ MODE_CREATE: WriteOperationV2._Mode.ValueType # 1
+ MODE_OVERWRITE: WriteOperationV2._Mode.ValueType # 2
+ MODE_OVERWRITE_PARTITIONS: WriteOperationV2._Mode.ValueType # 3
+ MODE_APPEND: WriteOperationV2._Mode.ValueType # 4
+ MODE_REPLACE: WriteOperationV2._Mode.ValueType # 5
+ MODE_CREATE_OR_REPLACE: WriteOperationV2._Mode.ValueType # 6
+
+ class Mode(_Mode, metaclass=_ModeEnumTypeWrapper): ...
+ MODE_UNSPECIFIED: WriteOperationV2.Mode.ValueType # 0
+ MODE_CREATE: WriteOperationV2.Mode.ValueType # 1
+ MODE_OVERWRITE: WriteOperationV2.Mode.ValueType # 2
+ MODE_OVERWRITE_PARTITIONS: WriteOperationV2.Mode.ValueType # 3
+ MODE_APPEND: WriteOperationV2.Mode.ValueType # 4
+ MODE_REPLACE: WriteOperationV2.Mode.ValueType # 5
+ MODE_CREATE_OR_REPLACE: WriteOperationV2.Mode.ValueType # 6
+
+ class OptionsEntry(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ KEY_FIELD_NUMBER: builtins.int
+ VALUE_FIELD_NUMBER: builtins.int
+ key: builtins.str
+ value: builtins.str
+ def __init__(
+ self,
+ *,
+ key: builtins.str = ...,
+ value: builtins.str = ...,
+ ) -> None: ...
+ def ClearField(
+ self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"]
+ ) -> None: ...
+
+ class TablePropertiesEntry(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ KEY_FIELD_NUMBER: builtins.int
+ VALUE_FIELD_NUMBER: builtins.int
+ key: builtins.str
+ value: builtins.str
+ def __init__(
+ self,
+ *,
+ key: builtins.str = ...,
+ value: builtins.str = ...,
+ ) -> None: ...
+ def ClearField(
+ self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"]
+ ) -> None: ...
+
+ INPUT_FIELD_NUMBER: builtins.int
+ TABLE_NAME_FIELD_NUMBER: builtins.int
+ PROVIDER_FIELD_NUMBER: builtins.int
+ PARTITIONING_COLUMNS_FIELD_NUMBER: builtins.int
+ OPTIONS_FIELD_NUMBER: builtins.int
+ TABLE_PROPERTIES_FIELD_NUMBER: builtins.int
+ MODE_FIELD_NUMBER: builtins.int
+ OVERWRITE_CONDITION_FIELD_NUMBER: builtins.int
+ @property
+ def input(self) -> pyspark.sql.connect.proto.relations_pb2.Relation:
+ """(Required) The output of the `input` relation will be persisted according to the options."""
+ table_name: builtins.str
+ """The destination of the write operation must be either a path or a table."""
+ provider: builtins.str
+ """A provider for the underlying output data source. Spark's default catalog supports
+ "parquet", "json", etc.
+ """
+ @property
+ def partitioning_columns(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+ pyspark.sql.connect.proto.expressions_pb2.Expression
+ ]:
+ """(Optional) List of columns for partitioning for output table created by `create`,
+ `createOrReplace`, or `replace`
+ """
+ @property
+ def options(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]:
+ """(Optional) A list of configuration options."""
+ @property
+ def table_properties(
+ self,
+ ) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]:
+ """(Optional) A list of table properties."""
+ mode: global___WriteOperationV2.Mode.ValueType
+ @property
+ def overwrite_condition(self) -> pyspark.sql.connect.proto.expressions_pb2.Expression:
+ """(Optional) A condition for overwrite saving mode"""
+ def __init__(
+ self,
+ *,
+ input: pyspark.sql.connect.proto.relations_pb2.Relation | None = ...,
+ table_name: builtins.str = ...,
+ provider: builtins.str = ...,
+ partitioning_columns: collections.abc.Iterable[
+ pyspark.sql.connect.proto.expressions_pb2.Expression
+ ]
+ | None = ...,
+ options: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
+ table_properties: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
+ mode: global___WriteOperationV2.Mode.ValueType = ...,
+ overwrite_condition: pyspark.sql.connect.proto.expressions_pb2.Expression | None = ...,
+ ) -> None: ...
+ def HasField(
+ self,
+ field_name: typing_extensions.Literal[
+ "input", b"input", "overwrite_condition", b"overwrite_condition"
+ ],
+ ) -> builtins.bool: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "input",
+ b"input",
+ "mode",
+ b"mode",
+ "options",
+ b"options",
+ "overwrite_condition",
+ b"overwrite_condition",
+ "partitioning_columns",
+ b"partitioning_columns",
+ "provider",
+ b"provider",
+ "table_name",
+ b"table_name",
+ "table_properties",
+ b"table_properties",
+ ],
+ ) -> None: ...
+
+global___WriteOperationV2 = WriteOperationV2
diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py
index f7a3fc1344c..8724348592e 100644
--- a/python/pyspark/sql/connect/readwriter.py
+++ b/python/pyspark/sql/connect/readwriter.py
@@ -22,17 +22,18 @@ from typing import Dict
from typing import Optional, Union, List, overload, Tuple, cast, Any
from typing import TYPE_CHECKING
-from pyspark.sql.connect.plan import Read, DataSource, LogicalPlan, WriteOperation
+from pyspark.sql.connect.plan import Read, DataSource, LogicalPlan, WriteOperation, WriteOperationV2
from pyspark.sql.types import StructType
from pyspark.sql.utils import to_str
from pyspark.sql.readwriter import (
DataFrameWriter as PySparkDataFrameWriter,
DataFrameReader as PySparkDataFrameReader,
+ DataFrameWriterV2 as PySparkDataFrameWriterV2,
)
if TYPE_CHECKING:
from pyspark.sql.connect.dataframe import DataFrame
- from pyspark.sql.connect._typing import OptionalPrimitiveType
+ from pyspark.sql.connect._typing import ColumnOrName, OptionalPrimitiveType
from pyspark.sql.connect.session import SparkSession
__all__ = ["DataFrameReader", "DataFrameWriter"]
@@ -603,6 +604,83 @@ class DataFrameWriter(OptionUtils):
raise NotImplementedError("jdbc() not supported for DataFrameWriter")
+class DataFrameWriterV2(OptionUtils):
+ def __init__(self, plan: "LogicalPlan", session: "SparkSession", table: str):
+ self._df: "LogicalPlan" = plan
+ self._spark: "SparkSession" = session
+ self._table_name: str = table
+ self._write: "WriteOperationV2" = WriteOperationV2(self._df, self._table_name)
+
+ def using(self, provider: str) -> "DataFrameWriterV2":
+ self._write.provider = provider
+ return self
+
+ using.__doc__ = PySparkDataFrameWriterV2.using.__doc__
+
+ def option(self, key: str, value: "OptionalPrimitiveType") -> "DataFrameWriterV2":
+ self._write.options[key] = to_str(value)
+ return self
+
+ option.__doc__ = PySparkDataFrameWriterV2.option.__doc__
+
+ def options(self, **options: "OptionalPrimitiveType") -> "DataFrameWriterV2":
+ for k in options:
+ self._write.options[k] = to_str(options[k])
+ return self
+
+ options.__doc__ = PySparkDataFrameWriterV2.options.__doc__
+
+ def tableProperty(self, property: str, value: str) -> "DataFrameWriterV2":
+ self._write.table_properties[property] = value
+ return self
+
+ tableProperty.__doc__ = PySparkDataFrameWriterV2.tableProperty.__doc__
+
+ def partitionedBy(self, col: "ColumnOrName", *cols: "ColumnOrName") -> "DataFrameWriterV2":
+ self._write.partitioning_columns = [col]
+ self._write.partitioning_columns.extend(cols)
+ return self
+
+ partitionedBy.__doc__ = PySparkDataFrameWriterV2.partitionedBy.__doc__
+
+ def create(self) -> None:
+ self._write.mode = "create"
+ self._spark.client.execute_command(self._write.command(self._spark.client))
+
+ create.__doc__ = PySparkDataFrameWriterV2.create.__doc__
+
+ def replace(self) -> None:
+ self._write.mode = "replace"
+ self._spark.client.execute_command(self._write.command(self._spark.client))
+
+ replace.__doc__ = PySparkDataFrameWriterV2.replace.__doc__
+
+ def createOrReplace(self) -> None:
+ self._write.mode = "create_or_replace"
+ self._spark.client.execute_command(self._write.command(self._spark.client))
+
+ createOrReplace.__doc__ = PySparkDataFrameWriterV2.createOrReplace.__doc__
+
+ def append(self) -> None:
+ self._write.mode = "append"
+ self._spark.client.execute_command(self._write.command(self._spark.client))
+
+ append.__doc__ = PySparkDataFrameWriterV2.append.__doc__
+
+ def overwrite(self, condition: "ColumnOrName") -> None:
+ self._write.mode = "overwrite"
+ self._write.overwrite_condition = condition
+ self._spark.client.execute_command(self._write.command(self._spark.client))
+
+ overwrite.__doc__ = PySparkDataFrameWriterV2.overwrite.__doc__
+
+ def overwritePartitions(self) -> None:
+ self._write.mode = "overwrite_partitions"
+ self._spark.client.execute_command(self._write.command(self._spark.client))
+
+ overwritePartitions.__doc__ = PySparkDataFrameWriterV2.overwritePartitions.__doc__
+
+
def _test() -> None:
import sys
import doctest
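One more hedged sketch, an overwrite with a condition column, which is forwarded as the proto's overwrite_condition (col comes from pyspark.sql.connect.functions; the table name is a placeholder):

    from pyspark.sql.connect.functions import col

    df.writeTo("testcat.table_name").overwrite(col("id") > 5)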
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index f9d9c199faf..77a35ca8240 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -5142,6 +5142,9 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
.. versionadded:: 3.1.0
+ .. versionchanged:: 3.4.0
+ Support Spark Connect.
+
Parameters
----------
table : str
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 8b083ae9054..b87fb6528bb 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -1921,6 +1921,9 @@ class DataFrameWriterV2:
to external storage using the v2 API.
.. versionadded:: 3.1.0
+
+ .. versionchanged:: 3.4.0
+ Support Spark Connect.
"""
def __init__(self, df: "DataFrame", table: str):
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 08fad856036..8328f3181a3 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -62,6 +62,7 @@ if should_test_connect:
from pyspark.sql.connect.session import SparkSession as RemoteSparkSession
from pyspark.sql.connect.client import ChannelBuilder
from pyspark.sql.connect.column import Column
+ from pyspark.sql.connect.readwriter import DataFrameWriterV2
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.connect.dataframe import DataFrame as CDataFrame
from pyspark.sql.connect.function_builder import udf
@@ -1982,6 +1983,27 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
ndf = self.connect.read.table("parquet_test")
self.assertEqual(set(df.collect()), set(ndf.collect()))
+ def test_writeTo_operations(self):
+ # SPARK-42002: Implement DataFrameWriterV2
+ import datetime
+ from pyspark.sql.connect.functions import col, years, months, days, hours, bucket
+
+ df = self.connect.createDataFrame(
+ [(1, datetime.datetime(2000, 1, 1), "foo")], ("id", "ts", "value")
+ )
+ writer = df.writeTo("table1")
+ self.assertIsInstance(writer.option("property", "value"), DataFrameWriterV2)
+ self.assertIsInstance(writer.options(property="value"), DataFrameWriterV2)
+ self.assertIsInstance(writer.using("source"), DataFrameWriterV2)
+ self.assertIsInstance(writer.partitionedBy(col("id")), DataFrameWriterV2)
+ self.assertIsInstance(writer.tableProperty("foo", "bar"), DataFrameWriterV2)
+ self.assertIsInstance(writer.partitionedBy(years("ts")), DataFrameWriterV2)
+ self.assertIsInstance(writer.partitionedBy(months("ts")), DataFrameWriterV2)
+ self.assertIsInstance(writer.partitionedBy(days("ts")), DataFrameWriterV2)
+ self.assertIsInstance(writer.partitionedBy(hours("ts")), DataFrameWriterV2)
+ self.assertIsInstance(writer.partitionedBy(bucket(11, "id")), DataFrameWriterV2)
+ self.assertIsInstance(writer.partitionedBy(bucket(3, "id"), hours("ts")), DataFrameWriterV2)
+
def test_agg_with_avg(self):
# SPARK-41325: groupby.avg()
df = (
@@ -2629,7 +2651,6 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
"_repr_html_",
"semanticHash",
"sameSemantics",
- "writeTo",
):
with self.assertRaises(NotImplementedError):
getattr(df, f)()
@@ -2681,7 +2702,6 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
def test_unsupported_io_functions(self):
# SPARK-41964: Disable unsupported functions.
- # DataFrameWriterV2 is also not implemented yet
df = self.connect.createDataFrame([(x, f"{x}") for x in range(100)], ["id", "name"])
for f in ("jdbc",):
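To run the new Python tests locally, something like the following should work from the Spark source root (exact runner flags may vary by branch):

    python/run-tests --testnames 'pyspark.sql.tests.connect.test_connect_basic'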