Posted to commits@spark.apache.org by gu...@apache.org on 2023/02/06 09:40:07 UTC

[spark] branch master updated: [SPARK-42002][CONNECT][PYTHON] Implement DataFrameWriterV2

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 537c04ff7ae [SPARK-42002][CONNECT][PYTHON] Implement DataFrameWriterV2
537c04ff7ae is described below

commit 537c04ff7ae8eddb0d52684a01ff7fa2ace91103
Author: Sandeep Singh <sa...@techaddict.me>
AuthorDate: Mon Feb 6 18:39:55 2023 +0900

    [SPARK-42002][CONNECT][PYTHON] Implement DataFrameWriterV2
    
    ### What changes were proposed in this pull request?
    Implement `DataFrameWriterV2` for the Spark Connect Python client, backed by a new `WriteOperationV2` command in the Connect protocol.
    
    ### Why are the changes needed?
    API parity with PySpark's existing `DataFrame.writeTo` / `DataFrameWriterV2`.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, `DataFrame.writeTo` is now implemented and returns a `DataFrameWriterV2` instead of raising `NotImplementedError`.
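    For example (a minimal sketch, assuming a Spark Connect session `spark`; catalog and table names are illustrative):
    
        df = spark.createDataFrame([(1, "a"), (2, "b")], ["id", "data"])
        df.writeTo("catalog.db.tbl").using("parquet").create()   # create a new table
        df.writeTo("catalog.db.tbl").append()                    # append rows
        df.writeTo("catalog.db.tbl").overwritePartitions()       # dynamically overwrite partitions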
    
    ### How was this patch tested?
    New unit tests in `SparkConnectProtoSuite` and `test_connect_basic.py`.
    
    Closes #39614 from techaddict/SPARK-42002.
    
    Authored-by: Sandeep Singh <sa...@techaddict.me>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 .../src/main/protobuf/spark/connect/commands.proto |  42 ++++-
 .../org/apache/spark/sql/connect/dsl/package.scala |  38 +++++
 .../sql/connect/planner/SparkConnectPlanner.scala  |  70 ++++++++
 .../connect/planner/SparkConnectProtoSuite.scala   | 188 ++++++++++++++++++++-
 python/pyspark/sql/connect/dataframe.py            |   9 +-
 python/pyspark/sql/connect/plan.py                 |  65 +++++++
 python/pyspark/sql/connect/proto/commands_pb2.py   |  84 +++++++--
 python/pyspark/sql/connect/proto/commands_pb2.pyi  | 165 +++++++++++++++++-
 python/pyspark/sql/connect/readwriter.py           |  82 ++++++++-
 python/pyspark/sql/dataframe.py                    |   3 +
 python/pyspark/sql/readwriter.py                   |   3 +
 .../sql/tests/connect/test_connect_basic.py        |  24 ++-
 12 files changed, 746 insertions(+), 27 deletions(-)

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
index ffacfc008a0..da81da844e7 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
@@ -18,6 +18,7 @@
 syntax = 'proto3';
 
 import "google/protobuf/any.proto";
+import "spark/connect/expressions.proto";
 import "spark/connect/relations.proto";
 import "spark/connect/types.proto";
 
@@ -33,6 +34,7 @@ message Command {
     CreateScalarFunction create_function = 1;
     WriteOperation write_operation = 2;
     CreateDataFrameViewCommand create_dataframe_view = 3;
+    WriteOperationV2 write_operation_v2 = 4;
 
     // This field is used to mark extensions to the protocol. When plugins generate arbitrary
     // Commands they can add them here. During the planning the correct resolution is done.
@@ -140,4 +142,42 @@ message WriteOperation {
     SAVE_MODE_ERROR_IF_EXISTS = 3;
     SAVE_MODE_IGNORE = 4;
   }
-}
\ No newline at end of file
+}
+
+// As writes are not directly handled during analysis and planning, they are modeled as commands.
+message WriteOperationV2 {
+  // (Required) The output of the `input` relation will be persisted according to the options.
+  Relation input = 1;
+
+  // (Required) The name of the destination table for this write.
+  string table_name = 2;
+
+  // A provider for the underlying output data source. Spark's default catalog supports
+  // "parquet", "json", etc.
+  string provider = 3;
+
+  // (Optional) A list of columns used to partition the output table created by `create`,
+  // `createOrReplace`, or `replace`.
+  repeated Expression partitioning_columns = 4;
+
+  // (Optional) A list of configuration options.
+  map<string, string> options = 5;
+
+  // (Optional) A list of table properties.
+  map<string, string> table_properties = 6;
+
+  Mode mode = 7;
+
+  enum Mode {
+    MODE_UNSPECIFIED = 0;
+    MODE_CREATE = 1;
+    MODE_OVERWRITE = 2;
+    MODE_OVERWRITE_PARTITIONS = 3;
+    MODE_APPEND = 4;
+    MODE_REPLACE = 5;
+    MODE_CREATE_OR_REPLACE = 6;
+  }
+
+  // (Optional) A condition for overwrite saving mode
+  Expression overwrite_condition = 8;
+}
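
For reference, a minimal client-side sketch (not part of the patch) of how this message can be
populated through the generated Python bindings; all field values are illustrative:

    from pyspark.sql.connect.proto import commands_pb2, relations_pb2

    # Build a Command carrying a WriteOperationV2, mirroring the fields defined above.
    cmd = commands_pb2.Command()
    op = cmd.write_operation_v2
    op.input.CopyFrom(relations_pb2.Relation())  # placeholder for the input relation
    op.table_name = "catalog.db.tbl"
    op.provider = "parquet"
    op.options["compression"] = "snappy"
    op.table_properties["owner"] = "someone"
    op.mode = commands_pb2.WriteOperationV2.MODE_CREATE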
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
index 1c98162c76e..88531286e24 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/dsl/package.scala
@@ -242,6 +242,44 @@ package object dsl {
               .setInput(logicalPlan))
           .build()
       }
+
+      def writeV2(
+          tableName: Option[String] = None,
+          provider: Option[String] = None,
+          options: Map[String, String] = Map.empty,
+          tableProperties: Map[String, String] = Map.empty,
+          partitionByCols: Seq[Expression] = Seq.empty,
+          mode: Option[String] = None,
+          overwriteCondition: Option[Expression] = None): Command = {
+        val writeOp = WriteOperationV2.newBuilder()
+        writeOp.setInput(logicalPlan)
+        tableName.foreach(writeOp.setTableName)
+        provider.foreach(writeOp.setProvider)
+        partitionByCols.foreach(writeOp.addPartitioningColumns)
+        options.foreach { case (k, v) =>
+          writeOp.putOptions(k, v)
+        }
+        tableProperties.foreach { case (k, v) =>
+          writeOp.putTableProperties(k, v)
+        }
+        mode.foreach { m =>
+          if (m == "MODE_CREATE") {
+            writeOp.setMode(WriteOperationV2.Mode.MODE_CREATE)
+          } else if (m == "MODE_OVERWRITE") {
+            writeOp.setMode(WriteOperationV2.Mode.MODE_OVERWRITE)
+            overwriteCondition.foreach(writeOp.setOverwriteCondition)
+          } else if (m == "MODE_OVERWRITE_PARTITIONS") {
+            writeOp.setMode(WriteOperationV2.Mode.MODE_OVERWRITE_PARTITIONS)
+          } else if (m == "MODE_APPEND") {
+            writeOp.setMode(WriteOperationV2.Mode.MODE_APPEND)
+          } else if (m == "MODE_REPLACE") {
+            writeOp.setMode(WriteOperationV2.Mode.MODE_REPLACE)
+          } else if (m == "MODE_CREATE_OR_REPLACE") {
+            writeOp.setMode(WriteOperationV2.Mode.MODE_CREATE_OR_REPLACE)
+          }
+        }
+        Command.newBuilder().setWriteOperationV2(writeOp.build()).build()
+      }
     }
   }
 
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 51d115ef1ca..08df2274840 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -45,6 +45,7 @@ import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.execution.command.CreateViewCommand
 import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
+import org.apache.spark.sql.functions.{col, expr}
 import org.apache.spark.sql.internal.CatalogImpl
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
@@ -1408,6 +1409,8 @@ class SparkConnectPlanner(val session: SparkSession) {
         handleWriteOperation(command.getWriteOperation)
       case proto.Command.CommandTypeCase.CREATE_DATAFRAME_VIEW =>
         handleCreateViewCommand(command.getCreateDataframeView)
+      case proto.Command.CommandTypeCase.WRITE_OPERATION_V2 =>
+        handleWriteOperationV2(command.getWriteOperationV2)
       case proto.Command.CommandTypeCase.EXTENSION =>
         handleCommandPlugin(command.getExtension)
       case _ => throw new UnsupportedOperationException(s"$command not supported.")
@@ -1546,6 +1549,73 @@ class SparkConnectPlanner(val session: SparkSession) {
     }
   }
 
+  /**
+   * Transforms the write operation and executes it.
+   *
+   * The input write operation contains a reference to the input plan, which is transformed into
+   * the corresponding logical plan. Afterwards, a DataFrameWriterV2 is created and the
+   * parameters of the WriteOperationV2 are translated into the corresponding method calls.
+   *
+   * @param writeOperation
+   */
+  def handleWriteOperationV2(writeOperation: proto.WriteOperationV2): Unit = {
+    // Transform the input plan into the logical plan.
+    val planner = new SparkConnectPlanner(session)
+    val plan = planner.transformRelation(writeOperation.getInput)
+    // And create a Dataset from the plan.
+    val dataset = Dataset.ofRows(session, logicalPlan = plan)
+
+    val w = dataset.writeTo(table = writeOperation.getTableName)
+
+    if (writeOperation.getOptionsCount > 0) {
+      writeOperation.getOptionsMap.asScala.foreach { case (key, value) => w.option(key, value) }
+    }
+
+    if (writeOperation.getTablePropertiesCount > 0) {
+      writeOperation.getTablePropertiesMap.asScala.foreach { case (key, value) =>
+        w.tableProperty(key, value)
+      }
+    }
+
+    if (writeOperation.getPartitioningColumnsCount > 0) {
+      val names = writeOperation.getPartitioningColumnsList.asScala
+        .map(transformExpression)
+        .map(Column(_))
+        .toSeq
+      w.partitionedBy(names.head, names.tail.toSeq: _*)
+    }
+
+    writeOperation.getMode match {
+      case proto.WriteOperationV2.Mode.MODE_CREATE =>
+        if (writeOperation.getProvider.nonEmpty) {
+          w.using(writeOperation.getProvider).create()
+        } else {
+          w.create()
+        }
+      case proto.WriteOperationV2.Mode.MODE_OVERWRITE =>
+        w.overwrite(Column(transformExpression(writeOperation.getOverwriteCondition)))
+      case proto.WriteOperationV2.Mode.MODE_OVERWRITE_PARTITIONS =>
+        w.overwritePartitions()
+      case proto.WriteOperationV2.Mode.MODE_APPEND =>
+        w.append()
+      case proto.WriteOperationV2.Mode.MODE_REPLACE =>
+        if (writeOperation.getProvider.nonEmpty) {
+          w.using(writeOperation.getProvider).replace()
+        } else {
+          w.replace()
+        }
+      case proto.WriteOperationV2.Mode.MODE_CREATE_OR_REPLACE =>
+        if (writeOperation.getProvider.nonEmpty) {
+          w.using(writeOperation.getProvider).createOrReplace()
+        } else {
+          w.createOrReplace()
+        }
+      case _ =>
+        throw new UnsupportedOperationException(
+          s"WriteOperationV2: mode ${writeOperation.getModeValue} not supported")
+    }
+  }
+
   private val emptyLocalRelation = LocalRelation(
     output = AttributeReference("value", StringType, false)() :: Nil,
     data = Seq.empty)
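
For reference, the mode dispatch above maps each WriteOperationV2.Mode value to one
DataFrameWriterV2 call on the server. A Python summary (a documentation-only sketch using the
generated proto bindings; the string values paraphrase the Scala handler):

    from pyspark.sql.connect.proto import commands_pb2 as pb

    # For create/replace/createOrReplace, using(provider) is applied first when a provider is set.
    MODE_TO_SERVER_CALL = {
        pb.WriteOperationV2.MODE_CREATE: "create()",
        pb.WriteOperationV2.MODE_OVERWRITE: "overwrite(condition)",
        pb.WriteOperationV2.MODE_OVERWRITE_PARTITIONS: "overwritePartitions()",
        pb.WriteOperationV2.MODE_APPEND: "append()",
        pb.WriteOperationV2.MODE_REPLACE: "replace()",
        pb.WriteOperationV2.MODE_CREATE_OR_REPLACE: "createOrReplace()",
    }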
diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
index 4c4a070bb4f..87117801eb7 100644
--- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
+++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala
@@ -18,14 +18,17 @@ package org.apache.spark.sql.connect.planner
 
 import java.nio.file.{Files, Paths}
 
+import scala.collection.JavaConverters._
+
 import com.google.protobuf.ByteString
 
 import org.apache.spark.SparkClassNotFoundException
 import org.apache.spark.connect.proto
+import org.apache.spark.connect.proto.Expression
 import org.apache.spark.connect.proto.Join.JoinType
 import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Row, SaveMode}
 import org.apache.spark.sql.catalyst.analysis
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericInternalRow, UnsafeProjection}
 import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter}
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -34,9 +37,13 @@ import org.apache.spark.sql.connect.dsl.commands._
 import org.apache.spark.sql.connect.dsl.expressions._
 import org.apache.spark.sql.connect.dsl.plans._
 import org.apache.spark.sql.connect.planner.LiteralValueProtoConverter.toConnectProtoValue
+import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog, TableCatalog}
+import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.CatalogHelper
 import org.apache.spark.sql.execution.arrow.ArrowConverters
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, Metadata, ShortType, StringType, StructField, StructType}
+import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.Utils
 
 /**
  * This suite is based on connect DSL and test that given same dataframe operations, whether
@@ -647,6 +654,185 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest {
     }
   }
 
+  test("WriteTo with create") {
+    withTable("testcat.table_name") {
+      spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)
+
+      val rows = Seq(
+        new GenericInternalRow(Array(1L, UTF8String.fromString("a"))),
+        new GenericInternalRow(Array(2L, UTF8String.fromString("b"))),
+        new GenericInternalRow(Array(3L, UTF8String.fromString("c"))))
+
+      val schema = StructType(Array(StructField("id", LongType), StructField("data", StringType)))
+      val inputRows = rows.map { row =>
+        val proj = UnsafeProjection.create(schema)
+        proj(row).copy()
+      }
+
+      val localRelationV2 = createLocalRelationProto(schema.toAttributes, inputRows)
+
+      val cmd = localRelationV2.writeV2(
+        tableName = Some("testcat.table_name"),
+        mode = Some("MODE_CREATE"))
+      transform(cmd)
+
+      val outputRows = spark.table("testcat.table_name").collect()
+      assert(outputRows.length == 3)
+    }
+  }
+
+  test("WriteTo with create and using") {
+    val defaultOwnership = Map(TableCatalog.PROP_OWNER -> Utils.getCurrentUserName())
+    withTable("testcat.table_name") {
+      spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)
+
+      val rows = Seq(
+        new GenericInternalRow(Array(1L, UTF8String.fromString("a"))),
+        new GenericInternalRow(Array(2L, UTF8String.fromString("b"))),
+        new GenericInternalRow(Array(3L, UTF8String.fromString("c"))))
+
+      val schema = StructType(Array(StructField("id", LongType), StructField("data", StringType)))
+      val inputRows = rows.map { row =>
+        val proj = UnsafeProjection.create(schema)
+        proj(row).copy()
+      }
+
+      val localRelationV2 = createLocalRelationProto(schema.toAttributes, inputRows)
+
+      val cmd = localRelationV2.writeV2(
+        tableName = Some("testcat.table_name"),
+        provider = Some("foo"),
+        mode = Some("MODE_CREATE"))
+      transform(cmd)
+
+      val outputRows = spark.table("testcat.table_name").collect()
+      assert(outputRows.length == 3)
+      val table = spark.sessionState.catalogManager
+        .catalog("testcat")
+        .asTableCatalog
+        .loadTable(Identifier.of(Array(), "table_name"))
+      assert(table.name === "testcat.table_name")
+      assert(table.schema === new StructType().add("id", LongType).add("data", StringType))
+      assert(table.partitioning.isEmpty)
+      assert(table.properties === (Map("provider" -> "foo") ++ defaultOwnership).asJava)
+    }
+  }
+
+  test("WriteTo with append") {
+    withTable("testcat.table_name") {
+      spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)
+
+      val rows = Seq(
+        new GenericInternalRow(Array(1L, UTF8String.fromString("a"))),
+        new GenericInternalRow(Array(2L, UTF8String.fromString("b"))),
+        new GenericInternalRow(Array(3L, UTF8String.fromString("c"))))
+
+      val schema = StructType(Array(StructField("id", LongType), StructField("data", StringType)))
+      val inputRows = rows.map { row =>
+        val proj = UnsafeProjection.create(schema)
+        proj(row).copy()
+      }
+
+      val localRelationV2 = createLocalRelationProto(schema.toAttributes, inputRows)
+
+      spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo")
+
+      assert(spark.table("testcat.table_name").collect().isEmpty)
+
+      val cmd = localRelationV2.writeV2(
+        tableName = Some("testcat.table_name"),
+        mode = Some("MODE_APPEND"))
+      transform(cmd)
+
+      val outputRows = spark.table("testcat.table_name").collect()
+      assert(outputRows.length == 3)
+    }
+  }
+
+  test("WriteTo with overwrite") {
+    withTable("testcat.table_name") {
+      spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)
+
+      val rows1 = (1L to 3L).map { i =>
+        new GenericInternalRow(Array(i, UTF8String.fromString("" + (i - 1 + 'a').toChar)))
+      }
+      val rows2 = (4L to 7L).map { i =>
+        new GenericInternalRow(Array(i, UTF8String.fromString("" + (i - 1 + 'a').toChar)))
+      }
+
+      val schema = StructType(Array(StructField("id", LongType), StructField("data", StringType)))
+      val inputRows1 = rows1.map { row =>
+        val proj = UnsafeProjection.create(schema)
+        proj(row).copy()
+      }
+      val inputRows2 = rows2.map { row =>
+        val proj = UnsafeProjection.create(schema)
+        proj(row).copy()
+      }
+
+      val localRelation1V2 = createLocalRelationProto(schema.toAttributes, inputRows1)
+      val localRelation2V2 = createLocalRelationProto(schema.toAttributes, inputRows2)
+
+      spark.sql(
+        "CREATE TABLE testcat.table_name (id bigint, data string) USING foo PARTITIONED BY (id)")
+
+      assert(spark.table("testcat.table_name").collect().isEmpty)
+
+      val cmd1 = localRelation1V2.writeV2(
+        tableName = Some("testcat.table_name"),
+        mode = Some("MODE_APPEND"))
+      transform(cmd1)
+
+      val outputRows1 = spark.table("testcat.table_name").collect()
+      assert(outputRows1.length == 3)
+
+      val overwriteCondition = Expression
+        .newBuilder()
+        .setLiteral(Expression.Literal.newBuilder().setBoolean(true))
+        .build()
+
+      val cmd2 = localRelation2V2.writeV2(
+        tableName = Some("testcat.table_name"),
+        mode = Some("MODE_OVERWRITE"),
+        overwriteCondition = Some(overwriteCondition))
+      transform(cmd2)
+
+      val outputRows2 = spark.table("testcat.table_name").collect()
+      assert(outputRows2.length == 4)
+    }
+  }
+
+  test("WriteTo with overwritePartitions") {
+    withTable("testcat.table_name") {
+      spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)
+
+      val rows = (4L to 7L).map { i =>
+        new GenericInternalRow(Array(i, UTF8String.fromString("" + (i - 1 + 'a').toChar)))
+      }
+
+      val schema = StructType(Array(StructField("id", LongType), StructField("data", StringType)))
+      val inputRows = rows.map { row =>
+        val proj = UnsafeProjection.create(schema)
+        proj(row).copy()
+      }
+
+      val localRelationV2 = createLocalRelationProto(schema.toAttributes, inputRows)
+
+      spark.sql(
+        "CREATE TABLE testcat.table_name (id bigint, data string) USING foo PARTITIONED BY (id)")
+
+      assert(spark.table("testcat.table_name").collect().isEmpty)
+
+      val cmd = localRelationV2.writeV2(
+        tableName = Some("testcat.table_name"),
+        mode = Some("MODE_OVERWRITE_PARTITIONS"))
+      transform(cmd)
+
+      val outputRows = spark.table("testcat.table_name").collect()
+      assert(outputRows.length == 4)
+    }
+  }
+
   test("Test CreateView") {
     withView("view1", "view2", "view3", "view4") {
       transform(localRelation.createView("view1", global = true, replace = true))
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index d61938e6108..9b0c911c10f 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -52,7 +52,7 @@ from pyspark.sql.dataframe import (
 from pyspark.errors import PySparkTypeError
 import pyspark.sql.connect.plan as plan
 from pyspark.sql.connect.group import GroupedData
-from pyspark.sql.connect.readwriter import DataFrameWriter
+from pyspark.sql.connect.readwriter import DataFrameWriter, DataFrameWriterV2
 from pyspark.sql.connect.column import Column
 from pyspark.sql.connect.expressions import UnresolvedRegex
 from pyspark.sql.connect.functions import (
@@ -1551,8 +1551,11 @@ class DataFrame:
     def sameSemantics(self, *args: Any, **kwargs: Any) -> None:
         raise NotImplementedError("sameSemantics() is not implemented.")
 
-    def writeTo(self, *args: Any, **kwargs: Any) -> None:
-        raise NotImplementedError("writeTo() is not implemented.")
+    def writeTo(self, table: str) -> "DataFrameWriterV2":
+        assert self._plan is not None
+        return DataFrameWriterV2(self._plan, self._session, table)
+
+    writeTo.__doc__ = PySparkDataFrame.writeTo.__doc__
 
     # SparkConnect specific API
     def offset(self, n: int) -> "DataFrame":
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index d6b15c66c14..0945adf6d20 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -1416,6 +1416,71 @@ class WriteOperation(LogicalPlan):
         pass
 
 
+class WriteOperationV2(LogicalPlan):
+    def __init__(self, child: "LogicalPlan", table_name: str) -> None:
+        super(WriteOperationV2, self).__init__(child)
+        self.table_name: Optional[str] = table_name
+        self.provider: Optional[str] = None
+        self.partitioning_columns: List["ColumnOrName"] = []
+        self.options: dict[str, Optional[str]] = {}
+        self.table_properties: dict[str, Optional[str]] = {}
+        self.mode: Optional[str] = None
+        self.overwrite_condition: Optional["ColumnOrName"] = None
+
+    def col_to_expr(self, col: "ColumnOrName", session: "SparkConnectClient") -> proto.Expression:
+        if isinstance(col, Column):
+            return col.to_plan(session)
+        else:
+            return self.unresolved_attr(col)
+
+    def command(self, session: "SparkConnectClient") -> proto.Command:
+        assert self._child is not None
+        plan = proto.Command()
+        plan.write_operation_v2.input.CopyFrom(self._child.plan(session))
+        if self.table_name is not None:
+            plan.write_operation_v2.table_name = self.table_name
+        if self.provider is not None:
+            plan.write_operation_v2.provider = self.provider
+
+        plan.write_operation_v2.partitioning_columns.extend(
+            [self.col_to_expr(x, session) for x in self.partitioning_columns]
+        )
+
+        for k in self.options:
+            if self.options[k] is None:
+                plan.write_operation_v2.options.pop(k, None)
+            else:
+                plan.write_operation_v2.options[k] = cast(str, self.options[k])
+
+        for k in self.table_properties:
+            if self.table_properties[k] is None:
+                plan.write_operation_v2.table_properties.pop(k, None)
+            else:
+                plan.write_operation_v2.table_properties[k] = cast(str, self.table_properties[k])
+
+        if self.mode is not None:
+            wm = self.mode.lower()
+            if wm == "create":
+                plan.write_operation_v2.mode = proto.WriteOperationV2.Mode.MODE_CREATE
+            elif wm == "overwrite":
+                plan.write_operation_v2.mode = proto.WriteOperationV2.Mode.MODE_OVERWRITE
+                if self.overwrite_condition is not None:
+                    plan.write_operation_v2.overwrite_condition.CopyFrom(
+                        self.col_to_expr(self.overwrite_condition, session)
+                    )
+            elif wm == "overwrite_partitions":
+                plan.write_operation_v2.mode = proto.WriteOperationV2.Mode.MODE_OVERWRITE_PARTITIONS
+            elif wm == "append":
+                plan.write_operation_v2.mode = proto.WriteOperationV2.Mode.MODE_APPEND
+            elif wm == "replace":
+                plan.write_operation_v2.mode = proto.WriteOperationV2.Mode.MODE_REPLACE
+            elif wm == "create_or_replace":
+                plan.write_operation_v2.mode = proto.WriteOperationV2.Mode.MODE_CREATE_OR_REPLACE
+            else:
+                raise ValueError(f"Unknown Mode value for DataFrame: {self.mode}")
+        return plan
+
+
 # Catalog API (internal-only)
 
 
diff --git a/python/pyspark/sql/connect/proto/commands_pb2.py b/python/pyspark/sql/connect/proto/commands_pb2.py
index bc52ce6e763..482e7dd4fcc 100644
--- a/python/pyspark/sql/connect/proto/commands_pb2.py
+++ b/python/pyspark/sql/connect/proto/commands_pb2.py
@@ -30,12 +30,13 @@ _sym_db = _symbol_database.Default()
 
 
 from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2
+from pyspark.sql.connect.proto import expressions_pb2 as spark_dot_connect_dot_expressions__pb2
 from pyspark.sql.connect.proto import relations_pb2 as spark_dot_connect_dot_relations__pb2
 from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__pb2
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xcb\x02\n\x07\x43ommand\x12N\n\x0f\x63reate_function\x18\x01 \x01(\x0b\x32#.spark.connect.CreateScalarFunctionH\x00R\x0e\x63reateFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x18\x03 \x01(\x0b\x32).spark.connect.CreateDataFra [...]
+    b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\x9c\x03\n\x07\x43ommand\x12N\n\x0f\x63reate_function\x18\x01 \x01(\x0b\x32#.spark.connect.CreateScalarFunctionH\x00R\x0e\x63reateFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x18\x03 \x0 [...]
 )
 
 
@@ -45,10 +46,16 @@ _CREATEDATAFRAMEVIEWCOMMAND = DESCRIPTOR.message_types_by_name["CreateDataFrameV
 _WRITEOPERATION = DESCRIPTOR.message_types_by_name["WriteOperation"]
 _WRITEOPERATION_OPTIONSENTRY = _WRITEOPERATION.nested_types_by_name["OptionsEntry"]
 _WRITEOPERATION_BUCKETBY = _WRITEOPERATION.nested_types_by_name["BucketBy"]
+_WRITEOPERATIONV2 = DESCRIPTOR.message_types_by_name["WriteOperationV2"]
+_WRITEOPERATIONV2_OPTIONSENTRY = _WRITEOPERATIONV2.nested_types_by_name["OptionsEntry"]
+_WRITEOPERATIONV2_TABLEPROPERTIESENTRY = _WRITEOPERATIONV2.nested_types_by_name[
+    "TablePropertiesEntry"
+]
 _CREATESCALARFUNCTION_FUNCTIONLANGUAGE = _CREATESCALARFUNCTION.enum_types_by_name[
     "FunctionLanguage"
 ]
 _WRITEOPERATION_SAVEMODE = _WRITEOPERATION.enum_types_by_name["SaveMode"]
+_WRITEOPERATIONV2_MODE = _WRITEOPERATIONV2.enum_types_by_name["Mode"]
 Command = _reflection.GeneratedProtocolMessageType(
     "Command",
     (_message.Message,),
@@ -113,26 +120,69 @@ _sym_db.RegisterMessage(WriteOperation)
 _sym_db.RegisterMessage(WriteOperation.OptionsEntry)
 _sym_db.RegisterMessage(WriteOperation.BucketBy)
 
+WriteOperationV2 = _reflection.GeneratedProtocolMessageType(
+    "WriteOperationV2",
+    (_message.Message,),
+    {
+        "OptionsEntry": _reflection.GeneratedProtocolMessageType(
+            "OptionsEntry",
+            (_message.Message,),
+            {
+                "DESCRIPTOR": _WRITEOPERATIONV2_OPTIONSENTRY,
+                "__module__": "spark.connect.commands_pb2"
+                # @@protoc_insertion_point(class_scope:spark.connect.WriteOperationV2.OptionsEntry)
+            },
+        ),
+        "TablePropertiesEntry": _reflection.GeneratedProtocolMessageType(
+            "TablePropertiesEntry",
+            (_message.Message,),
+            {
+                "DESCRIPTOR": _WRITEOPERATIONV2_TABLEPROPERTIESENTRY,
+                "__module__": "spark.connect.commands_pb2"
+                # @@protoc_insertion_point(class_scope:spark.connect.WriteOperationV2.TablePropertiesEntry)
+            },
+        ),
+        "DESCRIPTOR": _WRITEOPERATIONV2,
+        "__module__": "spark.connect.commands_pb2"
+        # @@protoc_insertion_point(class_scope:spark.connect.WriteOperationV2)
+    },
+)
+_sym_db.RegisterMessage(WriteOperationV2)
+_sym_db.RegisterMessage(WriteOperationV2.OptionsEntry)
+_sym_db.RegisterMessage(WriteOperationV2.TablePropertiesEntry)
+
 if _descriptor._USE_C_DESCRIPTORS == False:
 
     DESCRIPTOR._options = None
     DESCRIPTOR._serialized_options = b"\n\036org.apache.spark.connect.protoP\001"
     _WRITEOPERATION_OPTIONSENTRY._options = None
     _WRITEOPERATION_OPTIONSENTRY._serialized_options = b"8\001"
-    _COMMAND._serialized_start = 133
-    _COMMAND._serialized_end = 464
-    _CREATESCALARFUNCTION._serialized_start = 467
-    _CREATESCALARFUNCTION._serialized_end = 1002
-    _CREATESCALARFUNCTION_FUNCTIONLANGUAGE._serialized_start = 840
-    _CREATESCALARFUNCTION_FUNCTIONLANGUAGE._serialized_end = 979
-    _CREATEDATAFRAMEVIEWCOMMAND._serialized_start = 1005
-    _CREATEDATAFRAMEVIEWCOMMAND._serialized_end = 1155
-    _WRITEOPERATION._serialized_start = 1158
-    _WRITEOPERATION._serialized_end = 1900
-    _WRITEOPERATION_OPTIONSENTRY._serialized_start = 1596
-    _WRITEOPERATION_OPTIONSENTRY._serialized_end = 1654
-    _WRITEOPERATION_BUCKETBY._serialized_start = 1656
-    _WRITEOPERATION_BUCKETBY._serialized_end = 1747
-    _WRITEOPERATION_SAVEMODE._serialized_start = 1750
-    _WRITEOPERATION_SAVEMODE._serialized_end = 1887
+    _WRITEOPERATIONV2_OPTIONSENTRY._options = None
+    _WRITEOPERATIONV2_OPTIONSENTRY._serialized_options = b"8\001"
+    _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._options = None
+    _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_options = b"8\001"
+    _COMMAND._serialized_start = 166
+    _COMMAND._serialized_end = 578
+    _CREATESCALARFUNCTION._serialized_start = 581
+    _CREATESCALARFUNCTION._serialized_end = 1116
+    _CREATESCALARFUNCTION_FUNCTIONLANGUAGE._serialized_start = 954
+    _CREATESCALARFUNCTION_FUNCTIONLANGUAGE._serialized_end = 1093
+    _CREATEDATAFRAMEVIEWCOMMAND._serialized_start = 1119
+    _CREATEDATAFRAMEVIEWCOMMAND._serialized_end = 1269
+    _WRITEOPERATION._serialized_start = 1272
+    _WRITEOPERATION._serialized_end = 2014
+    _WRITEOPERATION_OPTIONSENTRY._serialized_start = 1710
+    _WRITEOPERATION_OPTIONSENTRY._serialized_end = 1768
+    _WRITEOPERATION_BUCKETBY._serialized_start = 1770
+    _WRITEOPERATION_BUCKETBY._serialized_end = 1861
+    _WRITEOPERATION_SAVEMODE._serialized_start = 1864
+    _WRITEOPERATION_SAVEMODE._serialized_end = 2001
+    _WRITEOPERATIONV2._serialized_start = 2017
+    _WRITEOPERATIONV2._serialized_end = 2812
+    _WRITEOPERATIONV2_OPTIONSENTRY._serialized_start = 1710
+    _WRITEOPERATIONV2_OPTIONSENTRY._serialized_end = 1768
+    _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_start = 2584
+    _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_end = 2650
+    _WRITEOPERATIONV2_MODE._serialized_start = 2653
+    _WRITEOPERATIONV2_MODE._serialized_end = 2812
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/commands_pb2.pyi b/python/pyspark/sql/connect/proto/commands_pb2.pyi
index 2cebbf47c23..2447be6b467 100644
--- a/python/pyspark/sql/connect/proto/commands_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/commands_pb2.pyi
@@ -40,6 +40,7 @@ import google.protobuf.descriptor
 import google.protobuf.internal.containers
 import google.protobuf.internal.enum_type_wrapper
 import google.protobuf.message
+import pyspark.sql.connect.proto.expressions_pb2
 import pyspark.sql.connect.proto.relations_pb2
 import pyspark.sql.connect.proto.types_pb2
 import sys
@@ -62,6 +63,7 @@ class Command(google.protobuf.message.Message):
     CREATE_FUNCTION_FIELD_NUMBER: builtins.int
     WRITE_OPERATION_FIELD_NUMBER: builtins.int
     CREATE_DATAFRAME_VIEW_FIELD_NUMBER: builtins.int
+    WRITE_OPERATION_V2_FIELD_NUMBER: builtins.int
     EXTENSION_FIELD_NUMBER: builtins.int
     @property
     def create_function(self) -> global___CreateScalarFunction: ...
@@ -70,6 +72,8 @@ class Command(google.protobuf.message.Message):
     @property
     def create_dataframe_view(self) -> global___CreateDataFrameViewCommand: ...
     @property
+    def write_operation_v2(self) -> global___WriteOperationV2: ...
+    @property
     def extension(self) -> google.protobuf.any_pb2.Any:
         """This field is used to mark extensions to the protocol. When plugins generate arbitrary
         Commands they can add them here. During the planning the correct resolution is done.
@@ -80,6 +84,7 @@ class Command(google.protobuf.message.Message):
         create_function: global___CreateScalarFunction | None = ...,
         write_operation: global___WriteOperation | None = ...,
         create_dataframe_view: global___CreateDataFrameViewCommand | None = ...,
+        write_operation_v2: global___WriteOperationV2 | None = ...,
         extension: google.protobuf.any_pb2.Any | None = ...,
     ) -> None: ...
     def HasField(
@@ -95,6 +100,8 @@ class Command(google.protobuf.message.Message):
             b"extension",
             "write_operation",
             b"write_operation",
+            "write_operation_v2",
+            b"write_operation_v2",
         ],
     ) -> builtins.bool: ...
     def ClearField(
@@ -110,12 +117,18 @@ class Command(google.protobuf.message.Message):
             b"extension",
             "write_operation",
             b"write_operation",
+            "write_operation_v2",
+            b"write_operation_v2",
         ],
     ) -> None: ...
     def WhichOneof(
         self, oneof_group: typing_extensions.Literal["command_type", b"command_type"]
     ) -> typing_extensions.Literal[
-        "create_function", "write_operation", "create_dataframe_view", "extension"
+        "create_function",
+        "write_operation",
+        "create_dataframe_view",
+        "write_operation_v2",
+        "extension",
     ] | None: ...
 
 global___Command = Command
@@ -441,3 +454,153 @@ class WriteOperation(google.protobuf.message.Message):
     ) -> typing_extensions.Literal["path", "table_name"] | None: ...
 
 global___WriteOperation = WriteOperation
+
+class WriteOperationV2(google.protobuf.message.Message):
+    """As writes are not directly handled during analysis and planning, they are modeled as commands."""
+
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    class _Mode:
+        ValueType = typing.NewType("ValueType", builtins.int)
+        V: typing_extensions.TypeAlias = ValueType
+
+    class _ModeEnumTypeWrapper(
+        google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[
+            WriteOperationV2._Mode.ValueType
+        ],
+        builtins.type,
+    ):  # noqa: F821
+        DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor
+        MODE_UNSPECIFIED: WriteOperationV2._Mode.ValueType  # 0
+        MODE_CREATE: WriteOperationV2._Mode.ValueType  # 1
+        MODE_OVERWRITE: WriteOperationV2._Mode.ValueType  # 2
+        MODE_OVERWRITE_PARTITIONS: WriteOperationV2._Mode.ValueType  # 3
+        MODE_APPEND: WriteOperationV2._Mode.ValueType  # 4
+        MODE_REPLACE: WriteOperationV2._Mode.ValueType  # 5
+        MODE_CREATE_OR_REPLACE: WriteOperationV2._Mode.ValueType  # 6
+
+    class Mode(_Mode, metaclass=_ModeEnumTypeWrapper): ...
+    MODE_UNSPECIFIED: WriteOperationV2.Mode.ValueType  # 0
+    MODE_CREATE: WriteOperationV2.Mode.ValueType  # 1
+    MODE_OVERWRITE: WriteOperationV2.Mode.ValueType  # 2
+    MODE_OVERWRITE_PARTITIONS: WriteOperationV2.Mode.ValueType  # 3
+    MODE_APPEND: WriteOperationV2.Mode.ValueType  # 4
+    MODE_REPLACE: WriteOperationV2.Mode.ValueType  # 5
+    MODE_CREATE_OR_REPLACE: WriteOperationV2.Mode.ValueType  # 6
+
+    class OptionsEntry(google.protobuf.message.Message):
+        DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+        KEY_FIELD_NUMBER: builtins.int
+        VALUE_FIELD_NUMBER: builtins.int
+        key: builtins.str
+        value: builtins.str
+        def __init__(
+            self,
+            *,
+            key: builtins.str = ...,
+            value: builtins.str = ...,
+        ) -> None: ...
+        def ClearField(
+            self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"]
+        ) -> None: ...
+
+    class TablePropertiesEntry(google.protobuf.message.Message):
+        DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+        KEY_FIELD_NUMBER: builtins.int
+        VALUE_FIELD_NUMBER: builtins.int
+        key: builtins.str
+        value: builtins.str
+        def __init__(
+            self,
+            *,
+            key: builtins.str = ...,
+            value: builtins.str = ...,
+        ) -> None: ...
+        def ClearField(
+            self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"]
+        ) -> None: ...
+
+    INPUT_FIELD_NUMBER: builtins.int
+    TABLE_NAME_FIELD_NUMBER: builtins.int
+    PROVIDER_FIELD_NUMBER: builtins.int
+    PARTITIONING_COLUMNS_FIELD_NUMBER: builtins.int
+    OPTIONS_FIELD_NUMBER: builtins.int
+    TABLE_PROPERTIES_FIELD_NUMBER: builtins.int
+    MODE_FIELD_NUMBER: builtins.int
+    OVERWRITE_CONDITION_FIELD_NUMBER: builtins.int
+    @property
+    def input(self) -> pyspark.sql.connect.proto.relations_pb2.Relation:
+        """(Required) The output of the `input` relation will be persisted according to the options."""
+    table_name: builtins.str
+    """(Required) The name of the destination table for this write."""
+    provider: builtins.str
+    """A provider for the underlying output data source. Spark's default catalog supports
+    "parquet", "json", etc.
+    """
+    @property
+    def partitioning_columns(
+        self,
+    ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+        pyspark.sql.connect.proto.expressions_pb2.Expression
+    ]:
+        """(Optional) A list of columns used to partition the output table created by `create`,
+        `createOrReplace`, or `replace`
+        """
+    @property
+    def options(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]:
+        """(Optional) A list of configuration options."""
+    @property
+    def table_properties(
+        self,
+    ) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]:
+        """(Optional) A list of table properties."""
+    mode: global___WriteOperationV2.Mode.ValueType
+    @property
+    def overwrite_condition(self) -> pyspark.sql.connect.proto.expressions_pb2.Expression:
+        """(Optional) A condition for overwrite saving mode"""
+    def __init__(
+        self,
+        *,
+        input: pyspark.sql.connect.proto.relations_pb2.Relation | None = ...,
+        table_name: builtins.str = ...,
+        provider: builtins.str = ...,
+        partitioning_columns: collections.abc.Iterable[
+            pyspark.sql.connect.proto.expressions_pb2.Expression
+        ]
+        | None = ...,
+        options: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
+        table_properties: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
+        mode: global___WriteOperationV2.Mode.ValueType = ...,
+        overwrite_condition: pyspark.sql.connect.proto.expressions_pb2.Expression | None = ...,
+    ) -> None: ...
+    def HasField(
+        self,
+        field_name: typing_extensions.Literal[
+            "input", b"input", "overwrite_condition", b"overwrite_condition"
+        ],
+    ) -> builtins.bool: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "input",
+            b"input",
+            "mode",
+            b"mode",
+            "options",
+            b"options",
+            "overwrite_condition",
+            b"overwrite_condition",
+            "partitioning_columns",
+            b"partitioning_columns",
+            "provider",
+            b"provider",
+            "table_name",
+            b"table_name",
+            "table_properties",
+            b"table_properties",
+        ],
+    ) -> None: ...
+
+global___WriteOperationV2 = WriteOperationV2
diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py
index f7a3fc1344c..8724348592e 100644
--- a/python/pyspark/sql/connect/readwriter.py
+++ b/python/pyspark/sql/connect/readwriter.py
@@ -22,17 +22,18 @@ from typing import Dict
 from typing import Optional, Union, List, overload, Tuple, cast, Any
 from typing import TYPE_CHECKING
 
-from pyspark.sql.connect.plan import Read, DataSource, LogicalPlan, WriteOperation
+from pyspark.sql.connect.plan import Read, DataSource, LogicalPlan, WriteOperation, WriteOperationV2
 from pyspark.sql.types import StructType
 from pyspark.sql.utils import to_str
 from pyspark.sql.readwriter import (
     DataFrameWriter as PySparkDataFrameWriter,
     DataFrameReader as PySparkDataFrameReader,
+    DataFrameWriterV2 as PySparkDataFrameWriterV2,
 )
 
 if TYPE_CHECKING:
     from pyspark.sql.connect.dataframe import DataFrame
-    from pyspark.sql.connect._typing import OptionalPrimitiveType
+    from pyspark.sql.connect._typing import ColumnOrName, OptionalPrimitiveType
     from pyspark.sql.connect.session import SparkSession
 
 __all__ = ["DataFrameReader", "DataFrameWriter"]
@@ -603,6 +604,83 @@ class DataFrameWriter(OptionUtils):
         raise NotImplementedError("jdbc() not supported for DataFrameWriter")
 
 
+class DataFrameWriterV2(OptionUtils):
+    def __init__(self, plan: "LogicalPlan", session: "SparkSession", table: str):
+        self._df: "LogicalPlan" = plan
+        self._spark: "SparkSession" = session
+        self._table_name: str = table
+        self._write: "WriteOperationV2" = WriteOperationV2(self._df, self._table_name)
+
+    def using(self, provider: str) -> "DataFrameWriterV2":
+        self._write.provider = provider
+        return self
+
+    using.__doc__ = PySparkDataFrameWriterV2.using.__doc__
+
+    def option(self, key: str, value: "OptionalPrimitiveType") -> "DataFrameWriterV2":
+        self._write.options[key] = to_str(value)
+        return self
+
+    option.__doc__ = PySparkDataFrameWriterV2.option.__doc__
+
+    def options(self, **options: "OptionalPrimitiveType") -> "DataFrameWriterV2":
+        for k in options:
+            self._write.options[k] = to_str(options[k])
+        return self
+
+    options.__doc__ = PySparkDataFrameWriterV2.options.__doc__
+
+    def tableProperty(self, property: str, value: str) -> "DataFrameWriterV2":
+        self._write.table_properties[property] = value
+        return self
+
+    tableProperty.__doc__ = PySparkDataFrameWriterV2.tableProperty.__doc__
+
+    def partitionedBy(self, col: "ColumnOrName", *cols: "ColumnOrName") -> "DataFrameWriterV2":
+        self._write.partitioning_columns = [col]
+        self._write.partitioning_columns.extend(cols)
+        return self
+
+    partitionedBy.__doc__ = PySparkDataFrameWriterV2.partitionedBy.__doc__
+
+    def create(self) -> None:
+        self._write.mode = "create"
+        self._spark.client.execute_command(self._write.command(self._spark.client))
+
+    create.__doc__ = PySparkDataFrameWriterV2.create.__doc__
+
+    def replace(self) -> None:
+        self._write.mode = "replace"
+        self._spark.client.execute_command(self._write.command(self._spark.client))
+
+    replace.__doc__ = PySparkDataFrameWriterV2.replace.__doc__
+
+    def createOrReplace(self) -> None:
+        self._write.mode = "create_or_replace"
+        self._spark.client.execute_command(self._write.command(self._spark.client))
+
+    createOrReplace.__doc__ = PySparkDataFrameWriterV2.createOrReplace.__doc__
+
+    def append(self) -> None:
+        self._write.mode = "append"
+        self._spark.client.execute_command(self._write.command(self._spark.client))
+
+    append.__doc__ = PySparkDataFrameWriterV2.append.__doc__
+
+    def overwrite(self, condition: "ColumnOrName") -> None:
+        self._write.mode = "overwrite"
+        self._write.overwrite_condition = condition
+        self._spark.client.execute_command(self._write.command(self._spark.client))
+
+    overwrite.__doc__ = PySparkDataFrameWriterV2.overwrite.__doc__
+
+    def overwritePartitions(self) -> None:
+        self._write.mode = "overwrite_partitions"
+        self._spark.client.execute_command(self._write.command(self._spark.client))
+
+    overwritePartitions.__doc__ = PySparkDataFrameWriterV2.overwritePartitions.__doc__
+
+
 def _test() -> None:
     import sys
     import doctest
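
A minimal end-to-end sketch of the writer added above (assumes a reachable Spark Connect session
`spark`; catalog, table, and column names are illustrative):

    import datetime
    from pyspark.sql.connect.functions import days, lit

    df = spark.createDataFrame(
        [(1, datetime.datetime(2000, 1, 1), "a")], ["id", "ts", "data"]
    )
    # Create (or replace) a table partitioned by day, then overwrite rows matching a condition.
    df.writeTo("catalog.db.events").using("parquet").partitionedBy(days("ts")).createOrReplace()
    df.writeTo("catalog.db.events").overwrite(lit(True))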
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index f9d9c199faf..77a35ca8240 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -5142,6 +5142,9 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
 
         .. versionadded:: 3.1.0
 
+        .. versionchanged:: 3.4.0
+            Support Spark Connect.
+
         Parameters
         ----------
         table : str
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 8b083ae9054..b87fb6528bb 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -1921,6 +1921,9 @@ class DataFrameWriterV2:
     to external storage using the v2 API.
 
     .. versionadded:: 3.1.0
+
+    .. versionchanged:: 3.4.0
+        Support Spark Connect.
     """
 
     def __init__(self, df: "DataFrame", table: str):
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 08fad856036..8328f3181a3 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -62,6 +62,7 @@ if should_test_connect:
     from pyspark.sql.connect.session import SparkSession as RemoteSparkSession
     from pyspark.sql.connect.client import ChannelBuilder
     from pyspark.sql.connect.column import Column
+    from pyspark.sql.connect.readwriter import DataFrameWriterV2
     from pyspark.sql.dataframe import DataFrame
     from pyspark.sql.connect.dataframe import DataFrame as CDataFrame
     from pyspark.sql.connect.function_builder import udf
@@ -1982,6 +1983,27 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         ndf = self.connect.read.table("parquet_test")
         self.assertEqual(set(df.collect()), set(ndf.collect()))
 
+    def test_writeTo_operations(self):
+        # SPARK-42002: Implement DataFrameWriterV2
+        import datetime
+        from pyspark.sql.connect.functions import col, years, months, days, hours, bucket
+
+        df = self.connect.createDataFrame(
+            [(1, datetime.datetime(2000, 1, 1), "foo")], ("id", "ts", "value")
+        )
+        writer = df.writeTo("table1")
+        self.assertIsInstance(writer.option("property", "value"), DataFrameWriterV2)
+        self.assertIsInstance(writer.options(property="value"), DataFrameWriterV2)
+        self.assertIsInstance(writer.using("source"), DataFrameWriterV2)
+        self.assertIsInstance(writer.partitionedBy(col("id")), DataFrameWriterV2)
+        self.assertIsInstance(writer.tableProperty("foo", "bar"), DataFrameWriterV2)
+        self.assertIsInstance(writer.partitionedBy(years("ts")), DataFrameWriterV2)
+        self.assertIsInstance(writer.partitionedBy(months("ts")), DataFrameWriterV2)
+        self.assertIsInstance(writer.partitionedBy(days("ts")), DataFrameWriterV2)
+        self.assertIsInstance(writer.partitionedBy(hours("ts")), DataFrameWriterV2)
+        self.assertIsInstance(writer.partitionedBy(bucket(11, "id")), DataFrameWriterV2)
+        self.assertIsInstance(writer.partitionedBy(bucket(3, "id"), hours("ts")), DataFrameWriterV2)
+
     def test_agg_with_avg(self):
         # SPARK-41325: groupby.avg()
         df = (
@@ -2629,7 +2651,6 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             "_repr_html_",
             "semanticHash",
             "sameSemantics",
-            "writeTo",
         ):
             with self.assertRaises(NotImplementedError):
                 getattr(df, f)()
@@ -2681,7 +2702,6 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
 
     def test_unsupported_io_functions(self):
         # SPARK-41964: Disable unsupported functions.
-        # DataFrameWriterV2 is also not implemented yet
         df = self.connect.createDataFrame([(x, f"{x}") for x in range(100)], ["id", "name"])
 
         for f in ("jdbc",):


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org