Posted to commits@spark.apache.org by hv...@apache.org on 2023/02/22 20:51:45 UTC

[spark] branch master updated: [SPARK-42518][CONNECT] Scala Client DataFrameWriterV2

This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 0c4645eb6bb [SPARK-42518][CONNECT] Scala Client DataFrameWriterV2
0c4645eb6bb is described below

commit 0c4645eb6bb4740b92281d124053d4610090da34
Author: Zhen Li <zh...@users.noreply.github.com>
AuthorDate: Wed Feb 22 16:51:23 2023 -0400

    [SPARK-42518][CONNECT] Scala Client DataFrameWriterV2
    
    ### What changes were proposed in this pull request?
    Adding DataFrameWriterV2. This allows users to use the Dataset#writeTo API.
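    
    A minimal sketch of the API this adds (the table name below is illustrative):
    
        val df = spark.range(10)
        df.writeTo("catalog.db.events").using("parquet").create()
        df.writeTo("catalog.db.events").append()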
    
    ### Why are the changes needed?
    Implements Dataset#writeTo.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    E2E tests.
    
    This is based on https://github.com/apache/spark/pull/40061
    
    Closes #40075 from zhenlineo/write-v2.
    
    Authored-by: Zhen Li <zh...@users.noreply.github.com>
    Signed-off-by: Herman van Hovell <he...@databricks.com>
---
 .../org/apache/spark/sql/DataFrameWriterV2.scala   | 289 +++++++++++++++++++++
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  23 ++
 .../org/apache/spark/sql/ClientE2ETestSuite.scala  |  43 ++-
 .../scala/org/apache/spark/sql/DatasetSuite.scala  |  29 +++
 .../sql/connect/client/CompatibilitySuite.scala    |  11 +-
 .../connect/client/util/RemoteSparkSession.scala   |  12 +
 6 files changed, 399 insertions(+), 8 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala
new file mode 100644
index 00000000000..ed149223129
--- /dev/null
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameWriterV2.scala
@@ -0,0 +1,289 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.connect.proto
+
+/**
+ * Interface used to write a [[org.apache.spark.sql.Dataset]] to external storage using the v2
+ * API.
+ *
+ * @since 3.4.0
+ */
+@Experimental
+final class DataFrameWriterV2[T] private[sql] (table: String, ds: Dataset[T])
+    extends CreateTableWriter[T] {
+
+  private var provider: Option[String] = None
+
+  private val options = new mutable.HashMap[String, String]()
+
+  private val properties = new mutable.HashMap[String, String]()
+
+  private var partitioning: Option[Seq[proto.Expression]] = None
+
+  private var overwriteCondition: Option[proto.Expression] = None
+
+  override def using(provider: String): CreateTableWriter[T] = {
+    this.provider = Some(provider)
+    this
+  }
+
+  override def option(key: String, value: String): DataFrameWriterV2[T] = {
+    this.options.put(key, value)
+    this
+  }
+
+  override def options(options: scala.collection.Map[String, String]): DataFrameWriterV2[T] = {
+    options.foreach { case (key, value) =>
+      this.options.put(key, value)
+    }
+    this
+  }
+
+  override def options(options: java.util.Map[String, String]): DataFrameWriterV2[T] = {
+    this.options(options.asScala)
+    this
+  }
+
+  override def tableProperty(property: String, value: String): CreateTableWriter[T] = {
+    this.properties.put(property, value)
+    this
+  }
+
+  @scala.annotation.varargs
+  override def partitionedBy(column: Column, columns: Column*): CreateTableWriter[T] = {
+    val asTransforms = (column +: columns).map(_.expr)
+    this.partitioning = Some(asTransforms)
+    this
+  }
+
+  override def create(): Unit = {
+    executeWriteOperation(proto.WriteOperationV2.Mode.MODE_CREATE)
+  }
+
+  override def replace(): Unit = {
+    executeWriteOperation(proto.WriteOperationV2.Mode.MODE_REPLACE)
+  }
+
+  override def createOrReplace(): Unit = {
+    executeWriteOperation(proto.WriteOperationV2.Mode.MODE_CREATE_OR_REPLACE)
+  }
+
+  /**
+   * Append the contents of the data frame to the output table.
+   *
+   * If the output table does not exist, this operation will fail with
+   * [[org.apache.spark.sql.catalyst.analysis.NoSuchTableException]]. The data frame will be
+   * validated to ensure it is compatible with the existing table.
+   *
+   * @throws org.apache.spark.sql.catalyst.analysis.NoSuchTableException
+   *   If the table does not exist
+   */
+  def append(): Unit = {
+    executeWriteOperation(proto.WriteOperationV2.Mode.MODE_APPEND)
+  }
+
+  /**
+   * Overwrite rows matching the given filter condition with the contents of the data frame in the
+   * output table.
+   *
+   * If the output table does not exist, this operation will fail with
+   * [[org.apache.spark.sql.catalyst.analysis.NoSuchTableException]]. The data frame will be
+   * validated to ensure it is compatible with the existing table.
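+   *
+   * For example, to replace only the rows for which a boolean column evaluates to true (the
+   * table and column names below are illustrative):
+   * {{{
+   *   df.writeTo("catalog.db.events").overwrite(col("is_stale"))
+   * }}}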
+   *
+   * @throws org.apache.spark.sql.catalyst.analysis.NoSuchTableException
+   *   If the table does not exist
+   */
+  def overwrite(condition: Column): Unit = {
+    overwriteCondition = Some(condition.expr)
+    executeWriteOperation(proto.WriteOperationV2.Mode.MODE_OVERWRITE)
+  }
+
+  /**
+   * Overwrite all partitions for which the data frame contains at least one row with the contents
+   * of the data frame in the output table.
+   *
+   * This operation is equivalent to Hive's `INSERT OVERWRITE ... PARTITION`, which replaces
+   * partitions dynamically depending on the contents of the data frame.
+   *
+   * If the output table does not exist, this operation will fail with
+   * [[org.apache.spark.sql.catalyst.analysis.NoSuchTableException]]. The data frame will be
+   * validated to ensure it is compatible with the existing table.
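+   *
+   * For example, if the table is partitioned by a "day" column and the data frame only contains
+   * rows for two days, only those two partitions are replaced (the table name is illustrative):
+   * {{{
+   *   df.writeTo("catalog.db.events").overwritePartitions()
+   * }}}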
+   *
+   * @throws org.apache.spark.sql.catalyst.analysis.NoSuchTableException
+   *   If the table does not exist
+   */
+  def overwritePartitions(): Unit = {
+    executeWriteOperation(proto.WriteOperationV2.Mode.MODE_OVERWRITE_PARTITIONS)
+  }
+
+  private def executeWriteOperation(mode: proto.WriteOperationV2.Mode): Unit = {
+    val builder = proto.WriteOperationV2.newBuilder()
+
+    builder.setInput(ds.plan.getRoot)
+    builder.setTableName(table)
+    provider.foreach(builder.setProvider)
+
+    partitioning.foreach(columns => builder.addAllPartitioningColumns(columns.asJava))
+
+    options.foreach { case (k, v) =>
+      builder.putOptions(k, v)
+    }
+    properties.foreach { case (k, v) =>
+      builder.putTableProperties(k, v)
+    }
+
+    builder.setMode(mode)
+
+    overwriteCondition.foreach(builder.setOverwriteCondition)
+
+    ds.session.execute(proto.Command.newBuilder().setWriteOperationV2(builder).build())
+  }
+}
+
+/**
+ * Configuration methods common to create/replace operations and insert/overwrite operations.
+ * @tparam R
+ *   builder type to return
+ * @since 3.4.0
+ */
+trait WriteConfigMethods[R] {
+
+  /**
+   * Add a write option.
+   *
+   * @since 3.4.0
+   */
+  def option(key: String, value: String): R
+
+  /**
+   * Add a boolean output option.
+   *
+   * @since 3.4.0
+   */
+  def option(key: String, value: Boolean): R = option(key, value.toString)
+
+  /**
+   * Add a long output option.
+   *
+   * @since 3.4.0
+   */
+  def option(key: String, value: Long): R = option(key, value.toString)
+
+  /**
+   * Add a double output option.
+   *
+   * @since 3.4.0
+   */
+  def option(key: String, value: Double): R = option(key, value.toString)
+
+  /**
+   * Add write options from a Scala Map.
+   *
+   * @since 3.4.0
+   */
+  def options(options: scala.collection.Map[String, String]): R
+
+  /**
+   * Add write options from a Java Map.
+   *
+   * @since 3.4.0
+   */
+  def options(options: java.util.Map[String, String]): R
+}
+
+/**
+ * Trait to restrict calls to create and replace operations.
+ *
+ * @since 3.4.0
+ */
+trait CreateTableWriter[T] extends WriteConfigMethods[CreateTableWriter[T]] {
+
+  /**
+   * Create a new table from the contents of the data frame.
+   *
+   * The new table's schema, partition layout, properties, and other configuration will be based
+   * on the configuration set on this writer.
+   *
+   * If the output table exists, this operation will fail with
+   * [[org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException]].
+   *
+   * @throws org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
+   *   If the table already exists
+   */
+  def create(): Unit
+
+  /**
+   * Replace an existing table with the contents of the data frame.
+   *
+   * The existing table's schema, partition layout, properties, and other configuration will be
+   * replaced with the contents of the data frame and the configuration set on this writer.
+   *
+   * If the output table does not exist, this operation will fail with
+   * [[org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException]].
+   *
+   * @throws org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException
+   *   If the table does not exist
+   */
+  def replace(): Unit
+
+  /**
+   * Create a new table or replace an existing table with the contents of the data frame.
+   *
+   * The output table's schema, partition layout, properties, and other configuration will be
+   * based on the contents of the data frame and the configuration set on this writer. If the
+   * table exists, its configuration and data will be replaced.
+   */
+  def createOrReplace(): Unit
+
+  /**
+   * Partition the output table created by `create`, `createOrReplace`, or `replace` using the
+   * given columns or transforms.
+   *
+   * When specified, the table data will be stored by these values for efficient reads.
+   *
+   * For example, when a table is partitioned by day, it may be stored in a directory layout like:
+   * <ul> <li>`table/day=2019-06-01/`</li> <li>`table/day=2019-06-02/`</li> </ul>
+   *
+   * Partitioning is one of the most widely used techniques to optimize physical data layout. It
+   * provides a coarse-grained index for skipping unnecessary data reads when queries have
+   * predicates on the partitioned columns. In order for partitioning to work well, the number of
+   * distinct values in each column should typically be less than tens of thousands.
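+   *
+   * For example (the table and column names below are illustrative):
+   * {{{
+   *   df.writeTo("catalog.db.events").partitionedBy(col("day")).create()
+   * }}}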
+   *
+   * @since 3.4.0
+   */
+  def partitionedBy(column: Column, columns: Column*): CreateTableWriter[T]
+
+  /**
+   * Specifies a provider for the underlying output data source. Spark's default catalog supports
+   * "parquet", "json", etc.
+   *
+   * @since 3.4.0
+   */
+  def using(provider: String): CreateTableWriter[T]
+
+  /**
+   * Add a table property.
+   */
+  def tableProperty(property: String, value: String): CreateTableWriter[T]
+}
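
A note on the control flow above: the builder methods only buffer configuration on the client.
Nothing is sent to the server until one of the terminating methods (create, replace,
createOrReplace, append, overwrite, overwritePartitions) assembles a single WriteOperationV2
command and executes it through the session. A sketch of a full chain, mirroring the unit test
added below (the table, column, option, and property names are illustrative):

    df.writeTo("catalog.db.events")
      .using("parquet")
      .partitionedBy(col("day"))
      .tableProperty("key", "value")
      .options(Map("key2" -> "value2"))
      .createOrReplace()
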
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
index 3c876c05432..33125e5fd87 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2256,6 +2256,29 @@ class Dataset[T] private[sql] (val session: SparkSession, private[sql] val plan:
     new DataFrameWriter[T](this)
   }
 
+  /**
+   * Create a write configuration builder for v2 sources.
+   *
+   * This builder is used to configure and execute write operations. For example, to append to an
+   * existing table, run:
+   *
+   * {{{
+   *   df.writeTo("catalog.db.table").append()
+   * }}}
+   *
+   * This can also be used to create or replace existing tables:
+   *
+   * {{{
+   *   df.writeTo("catalog.db.table").partitionedBy($"col").createOrReplace()
+   * }}}
+   *
+   * @group basic
+   * @since 3.4.0
+   */
+  def writeTo(table: String): DataFrameWriterV2[T] = {
+    new DataFrameWriterV2[T](table, this)
+  }
+
   private[sql] def analyze: proto.AnalyzePlanResponse = {
     session.analyze(plan, proto.Explain.ExplainMode.SIMPLE)
   }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index 145d62feefc..e5d426e80f9 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -140,7 +140,7 @@ class ClientE2ETestSuite extends RemoteSparkSession {
   }
 
   test("write table") {
-    try {
+    withTable("myTable") {
       val df = spark.range(10).limit(3)
       df.write.mode(SaveMode.Overwrite).saveAsTable("myTable")
       spark.range(2).write.insertInto("myTable")
@@ -151,8 +151,45 @@ class ClientE2ETestSuite extends RemoteSparkSession {
       assert(result(2).getLong(0) == 1)
       assert(result(3).getLong(0) == 1)
       assert(result(4).getLong(0) == 2)
-    } finally {
-      spark.sql("drop table if exists myTable").collect()
+    }
+  }
+
+  test("writeTo with create and using") {
+    // TODO (SPARK-42519): Add more tests after we can set configs. See more WriteTo test cases
+    //  in SparkConnectProtoSuite.
+    //  e.g. spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)
+    withTable("myTableV2") {
+      spark.range(3).writeTo("myTableV2").using("parquet").create()
+      val result = spark.sql("select * from myTableV2").sort("id").collect()
+      assert(result.length == 3)
+      assert(result(0).getLong(0) == 0)
+      assert(result(1).getLong(0) == 1)
+      assert(result(2).getLong(0) == 2)
+    }
+  }
+
+  // TODO (SPARK-42519): Revisit this test after we can set configs.
+  //  e.g. spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)
+  test("writeTo with create and append") {
+    withTable("myTableV2") {
+      spark.range(3).writeTo("myTableV2").using("parquet").create()
+      withTable("myTableV2") {
+        assertThrows[StatusRuntimeException] {
+          // Fails to append: cannot write into v1 table `spark_catalog`.`default`.`mytablev2`.
+          spark.range(3).writeTo("myTableV2").append()
+        }
+      }
+    }
+  }
+
+  // TODO (SPARK-42519): Revisit this test after we can set configs.
+  //  e.g. spark.conf.set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName)
+  test("writeTo with create") {
+    withTable("myTableV2") {
+      assertThrows[StatusRuntimeException] {
+        // Fails to create the table because Hive support is required.
+        spark.range(3).writeTo("myTableV2").create()
+      }
     }
   }
 
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 66e597f2457..412371c4186 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -26,6 +26,7 @@ import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite
 
 import org.apache.spark.connect.proto
 import org.apache.spark.sql.connect.client.{DummySparkConnectService, SparkConnectClient}
+import org.apache.spark.sql.functions._
 
 // Add sample tests.
 // - sample fraction: simple.sample(0.1)
@@ -147,4 +148,32 @@ class DatasetSuite
     val actualPlan = service.getAndClearLatestInputPlan()
     assert(actualPlan.equals(expectedPlan))
   }
+
+  test("write V2") {
+    val df = ss.newDataset(_ => ()).limit(10)
+
+    val builder = proto.WriteOperationV2.newBuilder()
+    builder
+      .setInput(df.plan.getRoot)
+      .setTableName("t1")
+      .addPartitioningColumns(col("col99").expr)
+      .setProvider("json")
+      .putTableProperties("key", "value")
+      .putOptions("key2", "value2")
+      .setMode(proto.WriteOperationV2.Mode.MODE_CREATE_OR_REPLACE)
+
+    val expectedPlan = proto.Plan
+      .newBuilder()
+      .setCommand(proto.Command.newBuilder().setWriteOperationV2(builder))
+      .build()
+
+    df.writeTo("t1")
+      .partitionedBy(col("col99"))
+      .using("json")
+      .tableProperty("key", "value")
+      .options(Map("key2" -> "value2"))
+      .createOrReplace()
+    val actualPlan = service.getAndClearLatestInputPlan()
+    assert(actualPlan.equals(expectedPlan))
+  }
 }
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CompatibilitySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CompatibilitySuite.scala
index 81d58566cd9..010f3c616e6 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CompatibilitySuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CompatibilitySuite.scala
@@ -75,8 +75,9 @@ class CompatibilitySuite extends AnyFunSuite { // scalastyle:ignore funsuite
       // TODO(SPARK-42175) Add the Dataset object definition
       // IncludeByName("org.apache.spark.sql.Dataset$"),
       IncludeByName("org.apache.spark.sql.DataFrame"),
-      IncludeByName("org.apache.spark.sql.DataFrameReader"),
-      IncludeByName("org.apache.spark.sql.DataFrameWriter"),
+      IncludeByName("org.apache.spark.sql.DataFrameReader.*"),
+      IncludeByName("org.apache.spark.sql.DataFrameWriter.*"),
+      IncludeByName("org.apache.spark.sql.DataFrameWriterV2.*"),
       IncludeByName("org.apache.spark.sql.SparkSession"),
       IncludeByName("org.apache.spark.sql.SparkSession$")) ++ includeImplementedMethods(clientJar)
     val excludeRules = Seq(
@@ -86,6 +87,8 @@ class CompatibilitySuite extends AnyFunSuite { // scalastyle:ignore funsuite
       // Deprecated json methods and RDD related methods are skipped in the client.
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameReader.json"),
       ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameReader.csv"),
+      ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameReader.jdbc"),
+      ProblemFilters.exclude[Problem]("org.apache.spark.sql.DataFrameWriter.jdbc"),
       // Skip all shaded dependencies in the client.
       ProblemFilters.exclude[Problem]("org.sparkproject.*"),
       ProblemFilters.exclude[Problem]("org.apache.spark.connect.proto.*"),
@@ -135,9 +138,7 @@ class CompatibilitySuite extends AnyFunSuite { // scalastyle:ignore funsuite
       // TODO(SPARK-42175) Add all overloading methods. Temporarily mute compatibility check for \
       //  the Dataset methods, as too many overload methods are missing.
       // "org.apache.spark.sql.Dataset",
-      "org.apache.spark.sql.SparkSession",
-      "org.apache.spark.sql.DataFrameReader",
-      "org.apache.spark.sql.DataFrameWriter")
+      "org.apache.spark.sql.SparkSession")
 
     val clientClassLoader: URLClassLoader = new URLClassLoader(Seq(clientJar.toURI.toURL).toArray)
     clsNames
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
index 753e27efac3..50e3a51f759 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.connect.client.SparkConnectClient
 import org.apache.spark.sql.connect.client.util.IntegrationTestUtils._
 import org.apache.spark.sql.connect.common.config.ConnectCommon
+import org.apache.spark.util.Utils
 
 /**
 * A util class to start a local Spark Connect server in a different process for local E2E tests.
@@ -159,4 +160,15 @@ trait RemoteSparkSession
     spark = null
     super.afterAll()
   }
+
+  /**
+   * Drops the tables given in `tableNames` after calling `f`.
+   */
+  protected def withTable(tableNames: String*)(f: => Unit): Unit = {
+    Utils.tryWithSafeFinally(f) {
+      tableNames.foreach { name =>
+        spark.sql(s"DROP TABLE IF EXISTS $name").collect()
+      }
+    }
+  }
 }
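
The withTable helper added above runs the test body and then issues DROP TABLE IF EXISTS for each
listed table, even if the body throws. A usage sketch (the table name is illustrative):

    withTable("events") {
      spark.range(3).writeTo("events").using("parquet").create()
      assert(spark.sql("select * from events").collect().length == 3)
    }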


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org