Posted to commits@spark.apache.org by hu...@apache.org on 2022/08/15 17:58:39 UTC

[spark] branch master updated: [SPARK-40064][SQL] Use V2 Filter in SupportsOverwrite

This is an automated email from the ASF dual-hosted git repository.

huaxingao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1103343e71f [SPARK-40064][SQL] Use V2 Filter in SupportsOverwrite
1103343e71f is described below

commit 1103343e71fbcb478fa41941c87d2c28b0c09281
Author: huaxingao <hu...@apple.com>
AuthorDate: Mon Aug 15 10:58:14 2022 -0700

    [SPARK-40064][SQL] Use V2 Filter in SupportsOverwrite
    
    ### What changes were proposed in this pull request?
    Migrate `SupportsOverwrite` to use V2 Filter
    
    ### Why are the changes needed?
    This is part of the V2 Filter migration work.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes.
    This PR adds a new interface, `SupportsOverwriteV2`.
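
    A connector opts in by implementing the new interface on its `WriteBuilder`. A minimal
    sketch (hypothetical connector, not part of this patch) of the new method shape:

        import org.apache.spark.sql.connector.expressions.filter.Predicate
        import org.apache.spark.sql.connector.write.{SupportsOverwriteV2, WriteBuilder}

        class MyWriteBuilder extends SupportsOverwriteV2 {
          private var overwritePredicates: Array[Predicate] = Array.empty

          // Record the V2 predicates; rows matching ALL of them must be deleted before the
          // newly written data is committed. truncate() keeps its default, which delegates
          // to overwrite with a single AlwaysTrue predicate.
          override def overwrite(predicates: Array[Predicate]): WriteBuilder = {
            overwritePredicates = predicates
            this
          }
          // canOverwrite and build() omitted; both have default implementations.
        }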
    
    ### How was this patch tested?
    New tests.
    
    Closes #37502 from huaxingao/v2overwrite.
    
    Authored-by: huaxingao <hu...@apple.com>
    Signed-off-by: huaxingao <hu...@apple.com>
---
 .../sql/connector/catalog/TableCapability.java     |   2 +-
 .../connector/write/SupportsDynamicOverwrite.java  |   2 +-
 .../sql/connector/write/SupportsOverwrite.java     |  31 ++-
 ...ortsOverwrite.java => SupportsOverwriteV2.java} |  31 ++-
 .../sql/connector/catalog/InMemoryBaseTable.scala  | 138 +++---------
 .../sql/connector/catalog/InMemoryTable.scala      |  99 ++++++++-
 .../catalog/InMemoryTableWithV2Filter.scala        |  72 +++++--
 .../sql/execution/datasources/v2/V2Writes.scala    |  23 +-
 .../spark/sql/connector/DataSourceV2SQLSuite.scala | 233 +++------------------
 .../spark/sql/connector/DeleteFromTests.scala      | 132 ++++++++++++
 .../spark/sql/connector/V1WriteFallbackSuite.scala |   4 +-
 11 files changed, 412 insertions(+), 355 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCapability.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCapability.java
index 5bb42fb4b31..5732c0f3af4 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCapability.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCapability.java
@@ -76,7 +76,7 @@ public enum TableCapability {
    * Signals that the table can replace existing data that matches a filter with appended data in
    * a write operation.
    * <p>
-   * See {@link org.apache.spark.sql.connector.write.SupportsOverwrite}.
+   * See {@link org.apache.spark.sql.connector.write.SupportsOverwriteV2}.
    */
   OVERWRITE_BY_FILTER,
 
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsDynamicOverwrite.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsDynamicOverwrite.java
index 422cd71d345..0288a679891 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsDynamicOverwrite.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsDynamicOverwrite.java
@@ -27,7 +27,7 @@ import org.apache.spark.annotation.Evolving;
  * write does not contain data will remain unchanged.
  * <p>
  * This is provided to implement SQL compatible with Hive table operations but is not recommended.
- * Instead, use the {@link SupportsOverwrite overwrite by filter API} to explicitly replace data.
+ * Instead, use the {@link SupportsOverwriteV2 overwrite by filter API} to explicitly replace data.
  *
  * @since 3.0.0
  */
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsOverwrite.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsOverwrite.java
index b4e60257942..51bec236088 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsOverwrite.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsOverwrite.java
@@ -18,6 +18,8 @@
 package org.apache.spark.sql.connector.write;
 
 import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.filter.Predicate;
+import org.apache.spark.sql.internal.connector.PredicateUtils;
 import org.apache.spark.sql.sources.AlwaysTrue$;
 import org.apache.spark.sql.sources.Filter;
 
@@ -30,7 +32,24 @@ import org.apache.spark.sql.sources.Filter;
  * @since 3.0.0
  */
 @Evolving
-public interface SupportsOverwrite extends WriteBuilder, SupportsTruncate {
+public interface SupportsOverwrite extends SupportsOverwriteV2 {
+
+  /**
+   * Checks whether it is possible to overwrite data in a data source table that matches filter
+   * expressions.
+   * <p>
+   * Rows should be overwritten in the data source iff all of the filter expressions match.
+   * That is, the expressions must be interpreted as a set of filters that are ANDed together.
+   *
+   * @param filters V1 filter expressions, used to match data to overwrite
+   * @return true if the overwrite operation can be performed
+   *
+   * @since 3.4.0
+   */
+  default boolean canOverwrite(Filter[] filters) {
+    return true;
+  }
+
   /**
    * Configures a write to replace data matching the filters with data committed in the write.
    * <p>
@@ -42,6 +61,16 @@ public interface SupportsOverwrite extends WriteBuilder, SupportsTruncate {
    */
   WriteBuilder overwrite(Filter[] filters);
 
+  default boolean canOverwrite(Predicate[] predicates) {
+    Filter[] v1Filters = PredicateUtils.toV1(predicates);
+    if (v1Filters.length < predicates.length) return false;
+    return this.canOverwrite(v1Filters);
+  }
+
+  default WriteBuilder overwrite(Predicate[] predicates) {
+    return this.overwrite(PredicateUtils.toV1(predicates));
+  }
+
   @Override
   default WriteBuilder truncate() {
     return overwrite(new Filter[] { AlwaysTrue$.MODULE$ });
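
With the defaults above, an existing V1 source keeps working without code changes: Spark now
hands the builder V2 Predicates, the default canOverwrite/overwrite translate them back to V1
Filters through PredicateUtils.toV1, and canOverwrite reports false when any predicate has no
V1 equivalent. A minimal sketch (hypothetical legacy connector) of what continues to work:

    import org.apache.spark.sql.connector.write.{SupportsOverwrite, WriteBuilder}
    import org.apache.spark.sql.sources.Filter

    class LegacyWriteBuilder extends SupportsOverwrite {
      // Only the V1 method is implemented; the inherited defaults bridge V2 Predicates
      // to these Filters before this is called.
      override def overwrite(filters: Array[Filter]): WriteBuilder = {
        // delete rows matching all `filters`, then append the committed data
        this
      }
      // build() omitted; WriteBuilder provides a default.
    }
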
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsOverwrite.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsOverwriteV2.java
similarity index 60%
copy from sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsOverwrite.java
copy to sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsOverwriteV2.java
index b4e60257942..c1fcbfd38e1 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsOverwrite.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsOverwriteV2.java
@@ -18,8 +18,8 @@
 package org.apache.spark.sql.connector.write;
 
 import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.sources.AlwaysTrue$;
-import org.apache.spark.sql.sources.Filter;
+import org.apache.spark.sql.connector.expressions.filter.AlwaysTrue;
+import org.apache.spark.sql.connector.expressions.filter.Predicate;
 
 /**
  * Write builder trait for tables that support overwrite by filter.
@@ -27,23 +27,40 @@ import org.apache.spark.sql.sources.Filter;
  * Overwriting data by filter will delete any data that matches the filter and replace it with data
  * that is committed in the write.
  *
- * @since 3.0.0
+ * @since 3.4.0
  */
 @Evolving
-public interface SupportsOverwrite extends WriteBuilder, SupportsTruncate {
+public interface SupportsOverwriteV2 extends WriteBuilder, SupportsTruncate {
+
+  /**
+   * Checks whether it is possible to overwrite data in a data source table that matches filter
+   * expressions.
+   * <p>
+   * Rows should be overwritten in the data source iff all of the filter expressions match.
+   * That is, the expressions must be interpreted as a set of filters that are ANDed together.
+   *
+   * @param predicates V2 filter expressions, used to match data to overwrite
+   * @return true if the overwrite operation can be performed
+   *
+   * @since 3.4.0
+   */
+  default boolean canOverwrite(Predicate[] predicates) {
+    return true;
+  }
+
   /**
    * Configures a write to replace data matching the filters with data committed in the write.
    * <p>
    * Rows must be deleted from the data source if and only if all of the filters match. That is,
    * filters must be interpreted as ANDed together.
    *
-   * @param filters filters used to match data to overwrite
+   * @param predicates predicates used to match data to overwrite
    * @return this write builder for method chaining
    */
-  WriteBuilder overwrite(Filter[] filters);
+  WriteBuilder overwrite(Predicate[] predicates);
 
   @Override
   default WriteBuilder truncate() {
-    return overwrite(new Filter[] { AlwaysTrue$.MODULE$ });
+    return overwrite(new Predicate[] { new AlwaysTrue() });
   }
 }
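
For a native V2 implementation, canOverwrite is where a source refuses predicates it cannot
evaluate exactly; when it returns false, the planner fails the overwrite rather than silently
dropping rows. A minimal sketch (hypothetical helper, name-based like the in-memory test
tables below) of such a check:

    import org.apache.spark.sql.connector.expressions.filter.Predicate

    object OverwriteSupport {
      private val supportedNames = Set("=", "<=>", "IS NULL", "IS NOT NULL", "ALWAYS_TRUE")

      // A possible canOverwrite body: accept the overwrite only when every predicate is one
      // the sink can apply exactly, since correctness depends on deleting all and only the
      // matching rows.
      def canOverwrite(predicates: Array[Predicate]): Boolean =
        predicates.forall(p => supportedNames.contains(p.name()))
    }
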
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
index 1f8b416cf55..f139399ed76 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
@@ -45,7 +45,7 @@ import org.apache.spark.unsafe.types.UTF8String
 /**
  * A simple in-memory table. Rows are stored as a buffered group produced by each output task.
  */
-class InMemoryBaseTable(
+abstract class InMemoryBaseTable(
     val name: String,
     val schema: StructType,
     override val partitioning: Array[Transform],
@@ -337,59 +337,39 @@ class InMemoryBaseTable(
     }
   }
 
-  override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
-    InMemoryBaseTable.maybeSimulateFailedTableWrite(new CaseInsensitiveStringMap(properties))
-    InMemoryBaseTable.maybeSimulateFailedTableWrite(info.options)
+  abstract class InMemoryWriterBuilder() extends SupportsTruncate with SupportsDynamicOverwrite
+    with SupportsStreamingUpdateAsAppend {
 
-    new WriteBuilder with SupportsTruncate with SupportsOverwrite
-      with SupportsDynamicOverwrite with SupportsStreamingUpdateAsAppend {
+    protected var writer: BatchWrite = Append
+    protected var streamingWriter: StreamingWrite = StreamingAppend
 
-      private var writer: BatchWrite = Append
-      private var streamingWriter: StreamingWrite = StreamingAppend
-
-      override def truncate(): WriteBuilder = {
-        assert(writer == Append)
-        writer = TruncateAndAppend
-        streamingWriter = StreamingTruncateAndAppend
-        this
-      }
-
-      override def overwrite(filters: Array[Filter]): WriteBuilder = {
-        assert(writer == Append)
-        writer = new Overwrite(filters)
-        streamingWriter = new StreamingNotSupportedOperation(
-          s"overwrite (${filters.mkString("filters(", ", ", ")")})")
-        this
-      }
-
-      override def overwriteDynamicPartitions(): WriteBuilder = {
-        assert(writer == Append)
-        writer = DynamicOverwrite
-        streamingWriter = new StreamingNotSupportedOperation("overwriteDynamicPartitions")
-        this
-      }
+    override def overwriteDynamicPartitions(): WriteBuilder = {
+      assert(writer == Append)
+      writer = DynamicOverwrite
+      streamingWriter = new StreamingNotSupportedOperation("overwriteDynamicPartitions")
+      this
+    }
 
-      override def build(): Write = new Write with RequiresDistributionAndOrdering {
-        override def requiredDistribution: Distribution = distribution
+    override def build(): Write = new Write with RequiresDistributionAndOrdering {
+      override def requiredDistribution: Distribution = distribution
 
-        override def distributionStrictlyRequired: Boolean = isDistributionStrictlyRequired
+      override def distributionStrictlyRequired: Boolean = isDistributionStrictlyRequired
 
-        override def requiredOrdering: Array[SortOrder] = ordering
+      override def requiredOrdering: Array[SortOrder] = ordering
 
-        override def requiredNumPartitions(): Int = {
-          numPartitions.getOrElse(0)
-        }
+      override def requiredNumPartitions(): Int = {
+        numPartitions.getOrElse(0)
+      }
 
-        override def toBatch: BatchWrite = writer
+      override def toBatch: BatchWrite = writer
 
-        override def toStreaming: StreamingWrite = streamingWriter match {
-          case exc: StreamingNotSupportedOperation => exc.throwsException()
-          case s => s
-        }
+      override def toStreaming: StreamingWrite = streamingWriter match {
+        case exc: StreamingNotSupportedOperation => exc.throwsException()
+        case s => s
+      }
 
-        override def supportedCustomMetrics(): Array[CustomMetric] = {
-          Array(new InMemorySimpleCustomMetric)
-        }
+      override def supportedCustomMetrics(): Array[CustomMetric] = {
+        Array(new InMemorySimpleCustomMetric)
       }
     }
   }
@@ -402,7 +382,7 @@ class InMemoryBaseTable(
     override def abort(messages: Array[WriterCommitMessage]): Unit = {}
   }
 
-  private object Append extends TestBatchWrite {
+  protected object Append extends TestBatchWrite {
     override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
       withData(messages.map(_.asInstanceOf[BufferedRows]))
     }
@@ -416,24 +396,14 @@ class InMemoryBaseTable(
     }
   }
 
-  private class Overwrite(filters: Array[Filter]) extends TestBatchWrite {
-    import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
-    override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
-      val deleteKeys = InMemoryBaseTable.filtersToKeys(
-        dataMap.keys, partCols.map(_.toSeq.quoted), filters)
-      dataMap --= deleteKeys
-      withData(messages.map(_.asInstanceOf[BufferedRows]))
-    }
-  }
-
-  private object TruncateAndAppend extends TestBatchWrite {
+  protected object TruncateAndAppend extends TestBatchWrite {
     override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
       dataMap.clear
       withData(messages.map(_.asInstanceOf[BufferedRows]))
     }
   }
 
-  private abstract class TestStreamingWrite extends StreamingWrite {
+  protected abstract class TestStreamingWrite extends StreamingWrite {
     def createStreamingWriterFactory(info: PhysicalWriteInfo): StreamingDataWriterFactory = {
       BufferedRowsWriterFactory
     }
@@ -441,7 +411,7 @@ class InMemoryBaseTable(
     def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
   }
 
-  private class StreamingNotSupportedOperation(operation: String) extends TestStreamingWrite {
+  protected class StreamingNotSupportedOperation(operation: String) extends TestStreamingWrite {
     override def createStreamingWriterFactory(info: PhysicalWriteInfo): StreamingDataWriterFactory =
       throwsException()
 
@@ -463,7 +433,7 @@ class InMemoryBaseTable(
     }
   }
 
-  private object StreamingTruncateAndAppend extends TestStreamingWrite {
+  protected object StreamingTruncateAndAppend extends TestStreamingWrite {
     override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
       dataMap.synchronized {
         dataMap.clear
@@ -476,46 +446,7 @@ class InMemoryBaseTable(
 object InMemoryBaseTable {
   val SIMULATE_FAILED_WRITE_OPTION = "spark.sql.test.simulateFailedWrite"
 
-  def filtersToKeys(
-      keys: Iterable[Seq[Any]],
-      partitionNames: Seq[String],
-      filters: Array[Filter]): Iterable[Seq[Any]] = {
-    keys.filter { partValues =>
-      filters.flatMap(splitAnd).forall {
-        case EqualTo(attr, value) =>
-          value == extractValue(attr, partitionNames, partValues)
-        case EqualNullSafe(attr, value) =>
-          val attrVal = extractValue(attr, partitionNames, partValues)
-          if (attrVal == null && value === null) {
-            true
-          } else if (attrVal == null || value === null) {
-            false
-          } else {
-            value == attrVal
-          }
-        case IsNull(attr) =>
-          null == extractValue(attr, partitionNames, partValues)
-        case IsNotNull(attr) =>
-          null != extractValue(attr, partitionNames, partValues)
-        case AlwaysTrue() => true
-        case f =>
-          throw new IllegalArgumentException(s"Unsupported filter type: $f")
-      }
-    }
-  }
-
-  def supportsFilters(filters: Array[Filter]): Boolean = {
-    filters.flatMap(splitAnd).forall {
-      case _: EqualTo => true
-      case _: EqualNullSafe => true
-      case _: IsNull => true
-      case _: IsNotNull => true
-      case _: AlwaysTrue => true
-      case _ => false
-    }
-  }
-
-  private def extractValue(
+  def extractValue(
       attr: String,
       partFieldNames: Seq[String],
       partValues: Seq[Any]): Any = {
@@ -527,13 +458,6 @@ object InMemoryBaseTable {
     }
   }
 
-  private def splitAnd(filter: Filter): Seq[Filter] = {
-    filter match {
-      case And(left, right) => splitAnd(left) ++ splitAnd(right)
-      case _ => filter :: Nil
-    }
-  }
-
   def maybeSimulateFailedTableWrite(tableOptions: CaseInsensitiveStringMap): Unit = {
     if (tableOptions.getBoolean(SIMULATE_FAILED_WRITE_OPTION, false)) {
       throw new IllegalStateException("Manual write to table failure.")
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
index b82641a5d24..cd6821c8739 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
@@ -19,10 +19,14 @@ package org.apache.spark.sql.connector.catalog
 
 import java.util
 
+import org.scalatest.Assertions.assert
+
 import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
 import org.apache.spark.sql.connector.expressions.{SortOrder, Transform}
-import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.connector.write.{LogicalWriteInfo, SupportsOverwrite, WriteBuilder, WriterCommitMessage}
+import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
 /**
  * A simple in-memory table. Rows are stored as a buffered group produced by each output task.
@@ -40,12 +44,12 @@ class InMemoryTable(
     ordering, numPartitions, isDistributionStrictlyRequired) with SupportsDelete {
 
   override def canDeleteWhere(filters: Array[Filter]): Boolean = {
-    InMemoryBaseTable.supportsFilters(filters)
+    InMemoryTable.supportsFilters(filters)
   }
 
   override def deleteWhere(filters: Array[Filter]): Unit = dataMap.synchronized {
     import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
-    dataMap --= InMemoryBaseTable.filtersToKeys(dataMap.keys, partCols.map(_.toSeq.quoted), filters)
+    dataMap --= InMemoryTable.filtersToKeys(dataMap.keys, partCols.map(_.toSeq.quoted), filters)
   }
 
   override def withData(data: Array[BufferedRows]): InMemoryTable = {
@@ -64,4 +68,93 @@ class InMemoryTable(
     })
     this
   }
+
+  override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
+    InMemoryBaseTable.maybeSimulateFailedTableWrite(new CaseInsensitiveStringMap(properties))
+    InMemoryBaseTable.maybeSimulateFailedTableWrite(info.options)
+
+    new InMemoryWriterBuilderWithOverWrite()
+  }
+
+  private class InMemoryWriterBuilderWithOverWrite() extends InMemoryWriterBuilder
+    with SupportsOverwrite {
+
+    override def truncate(): WriteBuilder = {
+      assert(writer == Append)
+      writer = TruncateAndAppend
+      streamingWriter = StreamingTruncateAndAppend
+      this
+    }
+
+    override def overwrite(filters: Array[Filter]): WriteBuilder = {
+      assert(writer == Append)
+      writer = new Overwrite(filters)
+      streamingWriter = new StreamingNotSupportedOperation(
+        s"overwrite (${filters.mkString("filters(", ", ", ")")})")
+      this
+    }
+
+    override def canOverwrite(filters: Array[Filter]): Boolean = {
+      InMemoryTable.supportsFilters(filters)
+    }
+  }
+
+  private class Overwrite(filters: Array[Filter]) extends TestBatchWrite {
+    import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
+    override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
+      val deleteKeys = InMemoryTable.filtersToKeys(
+        dataMap.keys, partCols.map(_.toSeq.quoted), filters)
+      dataMap --= deleteKeys
+      withData(messages.map(_.asInstanceOf[BufferedRows]))
+    }
+  }
+}
+
+object InMemoryTable {
+
+  def filtersToKeys(
+      keys: Iterable[Seq[Any]],
+      partitionNames: Seq[String],
+      filters: Array[Filter]): Iterable[Seq[Any]] = {
+    keys.filter { partValues =>
+      filters.flatMap(splitAnd).forall {
+        case EqualTo(attr, value) =>
+          value == InMemoryBaseTable.extractValue(attr, partitionNames, partValues)
+        case EqualNullSafe(attr, value) =>
+          val attrVal = InMemoryBaseTable.extractValue(attr, partitionNames, partValues)
+          if (attrVal == null && value == null) {
+            true
+          } else if (attrVal == null || value == null) {
+            false
+          } else {
+            value == attrVal
+          }
+        case IsNull(attr) =>
+          null == InMemoryBaseTable.extractValue(attr, partitionNames, partValues)
+        case IsNotNull(attr) =>
+          null != InMemoryBaseTable.extractValue(attr, partitionNames, partValues)
+        case AlwaysTrue() => true
+        case f =>
+          throw new IllegalArgumentException(s"Unsupported filter type: $f")
+      }
+    }
+  }
+
+  def supportsFilters(filters: Array[Filter]): Boolean = {
+    filters.flatMap(splitAnd).forall {
+      case _: EqualTo => true
+      case _: EqualNullSafe => true
+      case _: IsNull => true
+      case _: IsNotNull => true
+      case _: AlwaysTrue => true
+      case _ => false
+    }
+  }
+
+  private def splitAnd(filter: Filter): Seq[Filter] = {
+    filter match {
+      case And(left, right) => splitAnd(left) ++ splitAnd(right)
+      case _ => filter :: Nil
+    }
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala
index 48000dd0d98..b4285f31dd7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala
@@ -19,9 +19,12 @@ package org.apache.spark.sql.connector.catalog
 
 import java.util
 
+import org.scalatest.Assertions.assert
+
 import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue, NamedReference, Transform}
 import org.apache.spark.sql.connector.expressions.filter.{And, Predicate}
 import org.apache.spark.sql.connector.read.{InputPartition, Scan, ScanBuilder, SupportsRuntimeV2Filtering}
+import org.apache.spark.sql.connector.write.{LogicalWriteInfo, SupportsOverwriteV2, WriteBuilder, WriterCommitMessage}
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
@@ -32,8 +35,8 @@ class InMemoryTableWithV2Filter(
     properties: util.Map[String, String])
   extends InMemoryBaseTable(name, schema, partitioning, properties) with SupportsDeleteV2 {
 
-  override def canDeleteWhere(filters: Array[Predicate]): Boolean = {
-    InMemoryTableWithV2Filter.supportsFilters(filters)
+  override def canDeleteWhere(predicates: Array[Predicate]): Boolean = {
+    InMemoryTableWithV2Filter.supportsPredicates(predicates)
   }
 
   override def deleteWhere(filters: Array[Predicate]): Unit = dataMap.synchronized {
@@ -84,6 +87,46 @@ class InMemoryTableWithV2Filter(
       }
     }
   }
+
+  override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
+    InMemoryBaseTable.maybeSimulateFailedTableWrite(new CaseInsensitiveStringMap(properties))
+    InMemoryBaseTable.maybeSimulateFailedTableWrite(info.options)
+
+    new InMemoryWriterBuilderWithOverWrite()
+  }
+
+  private class InMemoryWriterBuilderWithOverWrite() extends InMemoryWriterBuilder
+    with SupportsOverwriteV2 {
+
+    override def truncate(): WriteBuilder = {
+      assert(writer == Append)
+      writer = TruncateAndAppend
+      streamingWriter = StreamingTruncateAndAppend
+      this
+    }
+
+    override def overwrite(predicates: Array[Predicate]): WriteBuilder = {
+      assert(writer == Append)
+      writer = new Overwrite(predicates)
+      streamingWriter = new StreamingNotSupportedOperation(
+        s"overwrite (${predicates.mkString("filters(", ", ", ")")})")
+      this
+    }
+
+    override def canOverwrite(predicates: Array[Predicate]): Boolean = {
+      InMemoryTableWithV2Filter.supportsPredicates(predicates)
+    }
+  }
+
+  private class Overwrite(predicates: Array[Predicate]) extends TestBatchWrite {
+    import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
+    override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
+      val deleteKeys = InMemoryTableWithV2Filter.filtersToKeys(
+        dataMap.keys, partCols.map(_.toSeq.quoted), predicates)
+      dataMap --= deleteKeys
+      withData(messages.map(_.asInstanceOf[BufferedRows]))
+    }
+  }
 }
 
 object InMemoryTableWithV2Filter {
@@ -96,9 +139,10 @@ object InMemoryTableWithV2Filter {
       filters.flatMap(splitAnd).forall {
         case p: Predicate if p.name().equals("=") =>
           p.children()(1).asInstanceOf[LiteralValue[_]].value ==
-            extractValue(p.children()(0).toString, partitionNames, partValues)
+            InMemoryBaseTable.extractValue(p.children()(0).toString, partitionNames, partValues)
         case p: Predicate if p.name().equals("<=>") =>
-          val attrVal = extractValue(p.children()(0).toString, partitionNames, partValues)
+          val attrVal = InMemoryBaseTable
+            .extractValue(p.children()(0).toString, partitionNames, partValues)
           val value = p.children()(1).asInstanceOf[LiteralValue[_]].value
           if (attrVal == null && value == null) {
             true
@@ -109,10 +153,10 @@ object InMemoryTableWithV2Filter {
           }
         case p: Predicate if p.name().equals("IS NULL") =>
           val attr = p.children()(0).toString
-          null == extractValue(attr, partitionNames, partValues)
+          null == InMemoryBaseTable.extractValue(attr, partitionNames, partValues)
         case p: Predicate if p.name().equals("IS NOT NULL") =>
           val attr = p.children()(0).toString
-          null != extractValue(attr, partitionNames, partValues)
+          null != InMemoryBaseTable.extractValue(attr, partitionNames, partValues)
         case p: Predicate if p.name().equals("ALWAYS_TRUE") => true
         case f =>
           throw new IllegalArgumentException(s"Unsupported filter type: $f")
@@ -120,8 +164,8 @@ object InMemoryTableWithV2Filter {
     }
   }
 
-  def supportsFilters(filters: Array[Predicate]): Boolean = {
-    filters.flatMap(splitAnd).forall {
+  def supportsPredicates(predicates: Array[Predicate]): Boolean = {
+    predicates.flatMap(splitAnd).forall {
       case p: Predicate if p.name().equals("=") => true
       case p: Predicate if p.name().equals("<=>") => true
       case p: Predicate if p.name().equals("IS NULL") => true
@@ -131,18 +175,6 @@ object InMemoryTableWithV2Filter {
     }
   }
 
-  private def extractValue(
-      attr: String,
-      partFieldNames: Seq[String],
-      partValues: Seq[Any]): Any = {
-    partFieldNames.zipWithIndex.find(_._1 == attr) match {
-      case Some((_, partIndex)) =>
-        partValues(partIndex)
-      case _ =>
-        throw new IllegalArgumentException(s"Unknown filter attribute: $attr")
-    }
-  }
-
   private def splitAnd(filter: Predicate): Seq[Predicate] = {
     filter match {
       case and: And => splitAnd(and.left()) ++ splitAnd(and.right())
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala
index 4422743c5ac..2d47d94ff1d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala
@@ -24,12 +24,11 @@ import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, Ove
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
 import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table}
-import org.apache.spark.sql.connector.write.{LogicalWriteInfoImpl, SupportsDynamicOverwrite, SupportsOverwrite, SupportsTruncate, Write, WriteBuilder}
+import org.apache.spark.sql.connector.expressions.filter.Predicate
+import org.apache.spark.sql.connector.write.{LogicalWriteInfoImpl, SupportsDynamicOverwrite, SupportsOverwriteV2, SupportsTruncate, Write, WriteBuilder}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
-import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 import org.apache.spark.sql.execution.streaming.sources.{MicroBatchWrite, WriteToMicroBatchDataSource}
 import org.apache.spark.sql.internal.connector.SupportsStreamingUpdateAsAppend
-import org.apache.spark.sql.sources.{AlwaysTrue, Filter}
 import org.apache.spark.sql.streaming.OutputMode
 import org.apache.spark.sql.types.StructType
 
@@ -49,21 +48,21 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
 
     case o @ OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, options, _, None) =>
       // fail if any filter cannot be converted. correctness depends on removing all matching data.
-      val filters = splitConjunctivePredicates(deleteExpr).flatMap { pred =>
-        val filter = DataSourceStrategy.translateFilter(pred, supportNestedPredicatePushdown = true)
-        if (filter.isEmpty) {
+      val predicates = splitConjunctivePredicates(deleteExpr).flatMap { pred =>
+        val predicate = DataSourceV2Strategy.translateFilterV2(pred)
+        if (predicate.isEmpty) {
           throw QueryCompilationErrors.cannotTranslateExpressionToSourceFilterError(pred)
         }
-        filter
+        predicate
       }.toArray
 
       val table = r.table
       val writeBuilder = newWriteBuilder(table, options, query.schema)
       val write = writeBuilder match {
-        case builder: SupportsTruncate if isTruncate(filters) =>
+        case builder: SupportsTruncate if isTruncate(predicates) =>
           builder.truncate().build()
-        case builder: SupportsOverwrite =>
-          builder.overwrite(filters).build()
+        case builder: SupportsOverwriteV2 if builder.canOverwrite(predicates) =>
+          builder.overwrite(predicates).build()
         case _ =>
           throw QueryExecutionErrors.overwriteTableByUnsupportedExpressionError(table)
       }
@@ -123,8 +122,8 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
     }
   }
 
-  private def isTruncate(filters: Array[Filter]): Boolean = {
-    filters.length == 1 && filters(0).isInstanceOf[AlwaysTrue]
+  private def isTruncate(predicates: Array[Predicate]): Boolean = {
+    predicates.length == 1 && predicates(0).name().equals("ALWAYS_TRUE")
   }
 
   private def newWriteBuilder(
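
The rewritten rule above runs whenever an overwrite-by-expression write is planned, for example
through the DataFrameWriterV2 API. A usage sketch (table and column names are illustrative):

    import org.apache.spark.sql.functions.col

    // Plans an OverwriteByExpression: the condition is split into conjuncts, each conjunct is
    // translated to a V2 Predicate, and the result is handed to canOverwrite/overwrite above.
    spark.table("source")
      .writeTo("testcat.ns1.ns2.tbl")
      .overwrite(col("p") === 3)
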
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index 9ec5be46fc2..629a5ac83c8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -40,18 +40,13 @@ import org.apache.spark.sql.sources.SimpleScanSource
 import org.apache.spark.sql.types.{LongType, MetadataBuilder, StringType, StructField, StructType}
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 import org.apache.spark.unsafe.types.UTF8String
-import org.apache.spark.util.Utils
 
-class DataSourceV2SQLSuite
+abstract class DataSourceV2SQLSuite
   extends InsertIntoTests(supportsDynamicOverwrite = true, includeSQLOnlyTests = true)
-  with AlterTableTests with DatasourceV2SQLBase {
+  with DeleteFromTests with DatasourceV2SQLBase {
 
-  import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
-
-  private val v2Source = classOf[FakeV2Provider].getName
+  protected val v2Source = classOf[FakeV2Provider].getName
   override protected val v2Format = v2Source
-  override protected val catalogAndNamespace = "testcat.ns1.ns2."
-  private val defaultUser: String = Utils.getCurrentUserName()
 
   protected def doInsert(tableName: String, insert: DataFrame, mode: SaveMode): Unit = {
     val tmpView = "tmp_view"
@@ -66,6 +61,20 @@ class DataSourceV2SQLSuite
     checkAnswer(spark.table(tableName), expected)
   }
 
+  protected def assertAnalysisError(
+      sqlStatement: String,
+      expectedError: String): Unit = {
+    val ex = intercept[AnalysisException] {
+      sql(sqlStatement)
+    }
+    assert(ex.getMessage.contains(expectedError))
+  }
+}
+
+class DataSourceV2SQLSuiteV1Filter extends DataSourceV2SQLSuite with AlterTableTests {
+  import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+
+  override protected val catalogAndNamespace = "testcat.ns1.ns2."
   override def getTableMetadata(tableName: String): Table = {
     val nameParts = spark.sessionState.sqlParser.parseMultipartIdentifier(tableName)
     val v2Catalog = catalog(nameParts.head).asTableCatalog
@@ -622,8 +631,8 @@ class DataSourceV2SQLSuite
     assert(table.partitioning.isEmpty)
     assert(table.properties == withDefaultOwnership(Map("provider" -> v2Source)).asJava)
     assert(table.schema == new StructType()
-        .add("id", LongType)
-        .add("data", StringType))
+      .add("id", LongType)
+      .add("data", StringType))
 
     val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
     checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), spark.table("source"))
@@ -639,8 +648,8 @@ class DataSourceV2SQLSuite
     assert(table.partitioning.isEmpty)
     assert(table.properties == withDefaultOwnership(Map("provider" -> "foo")).asJava)
     assert(table.schema == new StructType()
-        .add("id", LongType)
-        .add("data", StringType))
+      .add("id", LongType)
+      .add("data", StringType))
 
     val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
     checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), spark.table("source"))
@@ -659,8 +668,8 @@ class DataSourceV2SQLSuite
     assert(table2.partitioning.isEmpty)
     assert(table2.properties == withDefaultOwnership(Map("provider" -> "foo")).asJava)
     assert(table2.schema == new StructType()
-        .add("id", LongType)
-        .add("data", StringType))
+      .add("id", LongType)
+      .add("data", StringType))
 
     val rdd2 = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
     checkAnswer(spark.internalCreateDataFrame(rdd2, table.schema), spark.table("source"))
@@ -677,8 +686,8 @@ class DataSourceV2SQLSuite
     assert(table.partitioning.isEmpty)
     assert(table.properties == withDefaultOwnership(Map("provider" -> "foo")).asJava)
     assert(table.schema == new StructType()
-        .add("id", LongType)
-        .add("data", StringType))
+      .add("id", LongType)
+      .add("data", StringType))
 
     val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
     checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), spark.table("source"))
@@ -708,8 +717,8 @@ class DataSourceV2SQLSuite
     assert(table.partitioning.isEmpty)
     assert(table.properties == withDefaultOwnership(Map("provider" -> "foo")).asJava)
     assert(table.schema == new StructType()
-        .add("id", LongType)
-        .add("data", StringType))
+      .add("id", LongType)
+      .add("data", StringType))
 
     val rdd = sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
     checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), spark.table("source"))
@@ -1526,148 +1535,6 @@ class DataSourceV2SQLSuite
     assert(e.message.contains("REPLACE TABLE is only supported with v2 tables"))
   }
 
-  test("DeleteFrom: basic - delete all") {
-    val t = "testcat.ns1.ns2.tbl"
-    withTable(t) {
-      sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)")
-      sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
-      sql(s"DELETE FROM $t")
-      checkAnswer(spark.table(t), Seq())
-    }
-  }
-
-  test("DeleteFrom with v2 filtering: basic - delete all") {
-    val t = "testv2filter.ns1.ns2.tbl"
-    withTable(t) {
-      sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)")
-      sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
-      sql(s"DELETE FROM $t")
-      checkAnswer(spark.table(t), Seq())
-    }
-  }
-
-  test("DeleteFrom: basic - delete with where clause") {
-    val t = "testcat.ns1.ns2.tbl"
-    withTable(t) {
-      sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)")
-      sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
-      sql(s"DELETE FROM $t WHERE id = 2")
-      checkAnswer(spark.table(t), Seq(
-        Row(3, "c", 3)))
-    }
-  }
-
-  test("DeleteFrom with v2 filtering: basic - delete with where clause") {
-    val t = "testv2filter.ns1.ns2.tbl"
-    withTable(t) {
-      sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)")
-      sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
-      sql(s"DELETE FROM $t WHERE id = 2")
-      checkAnswer(spark.table(t), Seq(
-        Row(3, "c", 3)))
-    }
-  }
-
-  test("DeleteFrom: delete from aliased target table") {
-    val t = "testcat.ns1.ns2.tbl"
-    withTable(t) {
-      sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)")
-      sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
-      sql(s"DELETE FROM $t AS tbl WHERE tbl.id = 2")
-      checkAnswer(spark.table(t), Seq(
-        Row(3, "c", 3)))
-    }
-  }
-
-  test("DeleteFrom with v2 filtering: delete from aliased target table") {
-    val t = "testv2filter.ns1.ns2.tbl"
-    withTable(t) {
-      sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)")
-      sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
-      sql(s"DELETE FROM $t AS tbl WHERE tbl.id = 2")
-      checkAnswer(spark.table(t), Seq(
-        Row(3, "c", 3)))
-    }
-  }
-
-  test("DeleteFrom: normalize attribute names") {
-    val t = "testcat.ns1.ns2.tbl"
-    withTable(t) {
-      sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)")
-      sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
-      sql(s"DELETE FROM $t AS tbl WHERE tbl.ID = 2")
-      checkAnswer(spark.table(t), Seq(
-        Row(3, "c", 3)))
-    }
-  }
-
-  test("DeleteFrom with v2 filtering: normalize attribute names") {
-    val t = "testv2filter.ns1.ns2.tbl"
-    withTable(t) {
-      sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)")
-      sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
-      sql(s"DELETE FROM $t AS tbl WHERE tbl.ID = 2")
-      checkAnswer(spark.table(t), Seq(
-        Row(3, "c", 3)))
-    }
-  }
-
-  test("DeleteFrom: fail if has subquery") {
-    val t = "testcat.ns1.ns2.tbl"
-    withTable(t) {
-      sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)")
-      sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
-      val exc = intercept[AnalysisException] {
-        sql(s"DELETE FROM $t WHERE id IN (SELECT id FROM $t)")
-      }
-
-      assert(spark.table(t).count === 3)
-      assert(exc.getMessage.contains("Delete by condition with subquery is not supported"))
-    }
-  }
-
-  test("DeleteFrom with v2 filtering: fail if has subquery") {
-    val t = "testv2filter.ns1.ns2.tbl"
-    withTable(t) {
-      sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)")
-      sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
-      val exc = intercept[AnalysisException] {
-        sql(s"DELETE FROM $t WHERE id IN (SELECT id FROM $t)")
-      }
-
-      assert(spark.table(t).count === 3)
-      assert(exc.getMessage.contains("Delete by condition with subquery is not supported"))
-    }
-  }
-
-  test("DeleteFrom: delete with unsupported predicates") {
-    val t = "testcat.ns1.ns2.tbl"
-    withTable(t) {
-      sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo")
-      sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
-      val exc = intercept[AnalysisException] {
-        sql(s"DELETE FROM $t WHERE id > 3 AND p > 3")
-      }
-
-      assert(spark.table(t).count === 3)
-      assert(exc.getMessage.contains(s"Cannot delete from table $t"))
-    }
-  }
-
-  test("DeleteFrom with v2 filtering: delete with unsupported predicates") {
-    val t = "testv2filter.ns1.ns2.tbl"
-    withTable(t) {
-      sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo")
-      sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
-      val exc = intercept[AnalysisException] {
-        sql(s"DELETE FROM $t WHERE id > 3 AND p > 3")
-      }
-
-      assert(spark.table(t).count === 3)
-      assert(exc.getMessage.contains(s"Cannot delete from table $t"))
-    }
-  }
-
   test("DeleteFrom: - delete with invalid predicate") {
     val t = "testcat.ns1.ns2.tbl"
     withTable(t) {
@@ -1682,37 +1549,6 @@ class DataSourceV2SQLSuite
     }
   }
 
-  test("DeleteFrom: DELETE is only supported with v2 tables") {
-    // unset this config to use the default v2 session catalog.
-    spark.conf.unset(V2_SESSION_CATALOG_IMPLEMENTATION.key)
-    val v1Table = "tbl"
-    withTable(v1Table) {
-      sql(s"CREATE TABLE $v1Table" +
-          s" USING ${classOf[SimpleScanSource].getName} OPTIONS (from=0,to=1)")
-      val exc = intercept[AnalysisException] {
-        sql(s"DELETE FROM $v1Table WHERE i = 2")
-      }
-
-      assert(exc.getMessage.contains("DELETE is only supported with v2 tables"))
-    }
-  }
-
-  test("SPARK-33652: DeleteFrom should refresh caches referencing the table") {
-    val t = "testcat.ns1.ns2.tbl"
-    val view = "view"
-    withTable(t) {
-      withTempView(view) {
-        sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)")
-        sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
-        sql(s"CACHE TABLE view AS SELECT id FROM $t")
-        assert(spark.table(view).count() == 3)
-
-        sql(s"DELETE FROM $t WHERE id = 2")
-        assert(spark.table(view).count() == 1)
-      }
-    }
-  }
-
   test("UPDATE TABLE") {
     val t = "testcat.ns1.ns2.tbl"
     withTable(t) {
@@ -2272,7 +2108,7 @@ class DataSourceV2SQLSuite
     val t1 = s"${catalogAndNamespace}table"
     withTable(t1) {
       sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
-        "PARTITIONED BY (bucket(4, id), id)")
+          "PARTITIONED BY (bucket(4, id), id)")
 
       val sqlQuery = spark.sql(s"SELECT * FROM $t1 WHERE index = 0")
       val dfQuery = spark.table(t1).filter("index = 0")
@@ -2796,15 +2632,6 @@ class DataSourceV2SQLSuite
     assert(e.message.contains(s"$sqlCommand is not supported for v2 tables"))
   }
 
-  private def assertAnalysisError(
-      sqlStatement: String,
-      expectedError: String): Unit = {
-    val ex = intercept[AnalysisException] {
-      sql(sqlStatement)
-    }
-    assert(ex.getMessage.contains(expectedError))
-  }
-
   private def assertAnalysisErrorClass(
       sqlStatement: String,
       expectedErrorClass: String,
@@ -2815,8 +2642,12 @@ class DataSourceV2SQLSuite
     assert(ex.getErrorClass == expectedErrorClass)
     assert(ex.messageParameters.sameElements(expectedErrorMessageParameters))
   }
+
 }
 
+class DataSourceV2SQLSuiteV2Filter extends DataSourceV2SQLSuite {
+  override protected val catalogAndNamespace = "testv2filter.ns1.ns2."
+}
 
 /** Used as a V2 DataSource for V2SessionCatalog DDL */
 class FakeV2Provider extends SimpleTableProvider {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTests.scala
new file mode 100644
index 00000000000..5ed64df6280
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTests.scala
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION
+import org.apache.spark.sql.sources.SimpleScanSource
+
+/**
+ * A collection of "DELETE" tests that can be run through the SQL APIs.
+ */
+trait DeleteFromTests extends DatasourceV2SQLBase {
+
+  protected val catalogAndNamespace: String
+
+  test("DeleteFrom with v2 filtering: basic - delete all") {
+    val t = s"${catalogAndNamespace}tbl"
+    withTable(t) {
+      sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)")
+      sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
+      sql(s"DELETE FROM $t")
+      checkAnswer(spark.table(t), Seq())
+    }
+  }
+
+  test("DeleteFrom with v2 filtering: basic - delete with where clause") {
+    val t = s"${catalogAndNamespace}tbl"
+    withTable(t) {
+      sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)")
+      sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
+      sql(s"DELETE FROM $t WHERE id = 2")
+      checkAnswer(spark.table(t), Seq(
+        Row(3, "c", 3)))
+    }
+  }
+
+  test("DeleteFrom with v2 filtering: delete from aliased target table") {
+    val t = s"${catalogAndNamespace}tbl"
+    withTable(t) {
+      sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)")
+      sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
+      sql(s"DELETE FROM $t AS tbl WHERE tbl.id = 2")
+      checkAnswer(spark.table(t), Seq(
+        Row(3, "c", 3)))
+    }
+  }
+
+  test("DeleteFrom with v2 filtering: normalize attribute names") {
+    val t = s"${catalogAndNamespace}tbl"
+    withTable(t) {
+      sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)")
+      sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
+      sql(s"DELETE FROM $t AS tbl WHERE tbl.ID = 2")
+      checkAnswer(spark.table(t), Seq(
+        Row(3, "c", 3)))
+    }
+  }
+
+  test("DeleteFrom with v2 filtering: fail if has subquery") {
+    val t = s"${catalogAndNamespace}tbl"
+    withTable(t) {
+      sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)")
+      sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
+      val exc = intercept[AnalysisException] {
+        sql(s"DELETE FROM $t WHERE id IN (SELECT id FROM $t)")
+      }
+
+      assert(spark.table(t).count === 3)
+      assert(exc.getMessage.contains("Delete by condition with subquery is not supported"))
+    }
+  }
+
+  test("DeleteFrom with v2 filtering: delete with unsupported predicates") {
+    val t = s"${catalogAndNamespace}tbl"
+    withTable(t) {
+      sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo")
+      sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
+      val exc = intercept[AnalysisException] {
+        sql(s"DELETE FROM $t WHERE id > 3 AND p > 3")
+      }
+
+      assert(spark.table(t).count === 3)
+      assert(exc.getMessage.contains(s"Cannot delete from table $t"))
+    }
+  }
+
+  test("DeleteFrom: DELETE is only supported with v2 tables") {
+    // unset this config to use the default v2 session catalog.
+    spark.conf.unset(V2_SESSION_CATALOG_IMPLEMENTATION.key)
+    val v1Table = "tbl"
+    withTable(v1Table) {
+      sql(s"CREATE TABLE $v1Table" +
+        s" USING ${classOf[SimpleScanSource].getName} OPTIONS (from=0,to=1)")
+      val exc = intercept[AnalysisException] {
+        sql(s"DELETE FROM $v1Table WHERE i = 2")
+      }
+
+      assert(exc.getMessage.contains("DELETE is only supported with v2 tables"))
+    }
+  }
+
+  test("SPARK-33652: DeleteFrom should refresh caches referencing the table") {
+    val t = s"${catalogAndNamespace}tbl"
+    val view = "view"
+    withTable(t) {
+      withTempView(view) {
+        sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)")
+        sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
+        sql(s"CACHE TABLE view AS SELECT id FROM $t")
+        assert(spark.table(view).count() == 3)
+
+        sql(s"DELETE FROM $t WHERE id = 2")
+        assert(spark.table(view).count() == 1)
+      }
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
index fe4f70e57ef..992c46cc6cd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SaveM
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreeNodeTag
-import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryBaseTable, SupportsRead, SupportsWrite, Table, TableCapability}
+import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable, SupportsRead, SupportsWrite, Table, TableCapability}
 import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform}
 import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, V1Scan}
 import org.apache.spark.sql.connector.write.{LogicalWriteInfo, LogicalWriteInfoImpl, SupportsOverwrite, SupportsTruncate, V1Write, WriteBuilder}
@@ -359,7 +359,7 @@ class InMemoryTableWithV1Fallback(
     }
 
     override def overwrite(filters: Array[Filter]): WriteBuilder = {
-      val keys = InMemoryBaseTable.filtersToKeys(dataMap.keys, partFieldNames, filters)
+      val keys = InMemoryTable.filtersToKeys(dataMap.keys, partFieldNames, filters)
       dataMap --= keys
       mode = "overwrite"
       this

