Posted to commits@spark.apache.org by hu...@apache.org on 2022/08/15 17:58:39 UTC
[spark] branch master updated: [SPARK-40064][SQL] Use V2 Filter in SupportsOverwrite
This is an automated email from the ASF dual-hosted git repository.
huaxingao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 1103343e71f [SPARK-40064][SQL] Use V2 Filter in SupportsOverwrite
1103343e71f is described below
commit 1103343e71fbcb478fa41941c87d2c28b0c09281
Author: huaxingao <hu...@apple.com>
AuthorDate: Mon Aug 15 10:58:14 2022 -0700
[SPARK-40064][SQL] Use V2 Filter in SupportsOverwrite
### What changes were proposed in this pull request?
Migrate `SupportsOverwrite` to use V2 Filter
### Why are the changes needed?
This is part of the V2 Filter migration work.
### Does this PR introduce _any_ user-facing change?
Yes
Adds a new interface, `SupportsOverwriteV2`.
### How was this patch tested?
New tests.
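For connector authors, the shape of the change is that a write builder can implement `SupportsOverwriteV2` and receive V2 `Predicate`s directly. A minimal sketch under the new API (the builder name and the accepted predicate set are illustrative, not part of this PR):

    import org.apache.spark.sql.connector.expressions.filter.Predicate
    import org.apache.spark.sql.connector.write.{SupportsOverwriteV2, WriteBuilder}

    // Hypothetical connector builder: accepts only equality and always-true
    // predicates, and records what it was asked to overwrite.
    class ExampleOverwriteBuilder extends SupportsOverwriteV2 {
      private var toOverwrite: Array[Predicate] = Array.empty

      override def canOverwrite(predicates: Array[Predicate]): Boolean =
        predicates.forall(p => p.name() == "=" || p.name() == "ALWAYS_TRUE")

      override def overwrite(predicates: Array[Predicate]): WriteBuilder = {
        toOverwrite = predicates
        this
      }
    }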
Closes #37502 from huaxingao/v2overwrite.
Authored-by: huaxingao <hu...@apple.com>
Signed-off-by: huaxingao <hu...@apple.com>
---
.../sql/connector/catalog/TableCapability.java | 2 +-
.../connector/write/SupportsDynamicOverwrite.java | 2 +-
.../sql/connector/write/SupportsOverwrite.java | 31 ++-
...ortsOverwrite.java => SupportsOverwriteV2.java} | 31 ++-
.../sql/connector/catalog/InMemoryBaseTable.scala | 138 +++---------
.../sql/connector/catalog/InMemoryTable.scala | 99 ++++++++-
.../catalog/InMemoryTableWithV2Filter.scala | 72 +++++--
.../sql/execution/datasources/v2/V2Writes.scala | 23 +-
.../spark/sql/connector/DataSourceV2SQLSuite.scala | 233 +++------------------
.../spark/sql/connector/DeleteFromTests.scala | 132 ++++++++++++
.../spark/sql/connector/V1WriteFallbackSuite.scala | 4 +-
11 files changed, 412 insertions(+), 355 deletions(-)
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCapability.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCapability.java
index 5bb42fb4b31..5732c0f3af4 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCapability.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCapability.java
@@ -76,7 +76,7 @@ public enum TableCapability {
* Signals that the table can replace existing data that matches a filter with appended data in
* a write operation.
* <p>
- * See {@link org.apache.spark.sql.connector.write.SupportsOverwrite}.
+ * See {@link org.apache.spark.sql.connector.write.SupportsOverwriteV2}.
*/
OVERWRITE_BY_FILTER,
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsDynamicOverwrite.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsDynamicOverwrite.java
index 422cd71d345..0288a679891 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsDynamicOverwrite.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsDynamicOverwrite.java
@@ -27,7 +27,7 @@ import org.apache.spark.annotation.Evolving;
* write does not contain data will remain unchanged.
* <p>
* This is provided to implement SQL compatible with Hive table operations but is not recommended.
- * Instead, use the {@link SupportsOverwrite overwrite by filter API} to explicitly replace data.
+ * Instead, use the {@link SupportsOverwriteV2 overwrite by filter API} to explicitly replace data.
*
* @since 3.0.0
*/
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsOverwrite.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsOverwrite.java
index b4e60257942..51bec236088 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsOverwrite.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsOverwrite.java
@@ -18,6 +18,8 @@
package org.apache.spark.sql.connector.write;
import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.filter.Predicate;
+import org.apache.spark.sql.internal.connector.PredicateUtils;
import org.apache.spark.sql.sources.AlwaysTrue$;
import org.apache.spark.sql.sources.Filter;
@@ -30,7 +32,24 @@ import org.apache.spark.sql.sources.Filter;
* @since 3.0.0
*/
@Evolving
-public interface SupportsOverwrite extends WriteBuilder, SupportsTruncate {
+public interface SupportsOverwrite extends SupportsOverwriteV2 {
+
+ /**
+ * Checks whether it is possible to overwrite data in a data source table that matches filter
+ * expressions.
+ * <p>
+ * Rows in the data source should be overwritten iff all of the filter expressions match.
+ * That is, the expressions must be interpreted as a set of filters that are ANDed together.
+ *
+ * @param filters V1 filter expressions, used to match data to overwrite
+ * @return true if the overwrite operation can be performed
+ *
+ * @since 3.4.0
+ */
+ default boolean canOverwrite(Filter[] filters) {
+ return true;
+ }
+
/**
* Configures a write to replace data matching the filters with data committed in the write.
* <p>
@@ -42,6 +61,16 @@ public interface SupportsOverwrite extends WriteBuilder, SupportsTruncate {
*/
WriteBuilder overwrite(Filter[] filters);
+ default boolean canOverwrite(Predicate[] predicates) {
+ Filter[] v1Filters = PredicateUtils.toV1(predicates);
+ if (v1Filters.length < predicates.length) return false;
+ return this.canOverwrite(v1Filters);
+ }
+
+ default WriteBuilder overwrite(Predicate[] predicates) {
+ return this.overwrite(PredicateUtils.toV1(predicates));
+ }
+
@Override
default WriteBuilder truncate() {
return overwrite(new Filter[] { AlwaysTrue$.MODULE$ });
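The two default methods above form the compatibility bridge: a V1-only implementation keeps overriding overwrite(Filter[]), while the planner (which now passes V2 predicates, see the V2Writes change below) reaches it through PredicateUtils.toV1. The length check is a correctness guard: if any predicate has no V1 translation, toV1 returns fewer filters, and canOverwrite must answer false rather than let Spark overwrite with a weaker condition. A hedged sketch of such a legacy implementation (names illustrative):

    import org.apache.spark.sql.connector.write.{SupportsOverwrite, WriteBuilder}
    import org.apache.spark.sql.sources.Filter

    // Hypothetical V1-era builder: unchanged by this migration, it satisfies
    // the new SupportsOverwriteV2 contract through the inherited defaults.
    class LegacyOverwriteBuilder extends SupportsOverwrite {
      private var toOverwrite: Array[Filter] = Array.empty

      override def overwrite(filters: Array[Filter]): WriteBuilder = {
        toOverwrite = filters
        this
      }
    }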
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsOverwrite.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsOverwriteV2.java
similarity index 60%
copy from sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsOverwrite.java
copy to sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsOverwriteV2.java
index b4e60257942..c1fcbfd38e1 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsOverwrite.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/SupportsOverwriteV2.java
@@ -18,8 +18,8 @@
package org.apache.spark.sql.connector.write;
import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.sources.AlwaysTrue$;
-import org.apache.spark.sql.sources.Filter;
+import org.apache.spark.sql.connector.expressions.filter.AlwaysTrue;
+import org.apache.spark.sql.connector.expressions.filter.Predicate;
/**
* Write builder trait for tables that support overwrite by filter.
@@ -27,23 +27,40 @@ import org.apache.spark.sql.sources.Filter;
* Overwriting data by filter will delete any data that matches the filter and replace it with data
* that is committed in the write.
*
- * @since 3.0.0
+ * @since 3.4.0
*/
@Evolving
-public interface SupportsOverwrite extends WriteBuilder, SupportsTruncate {
+public interface SupportsOverwriteV2 extends WriteBuilder, SupportsTruncate {
+
+ /**
+ * Checks whether it is possible to overwrite data in a data source table that matches filter
+ * expressions.
+ * <p>
+ * Rows in the data source should be overwritten iff all of the filter expressions match.
+ * That is, the expressions must be interpreted as a set of filters that are ANDed together.
+ *
+ * @param predicates V2 filter expressions, used to match data to overwrite
+ * @return true if the overwrite operation can be performed
+ *
+ * @since 3.4.0
+ */
+ default boolean canOverwrite(Predicate[] predicates) {
+ return true;
+ }
+
/**
* Configures a write to replace data matching the filters with data committed in the write.
* <p>
* Rows must be deleted from the data source if and only if all of the filters match. That is,
* filters must be interpreted as ANDed together.
*
- * @param filters filters used to match data to overwrite
+ * @param predicates filters used to match data to overwrite
* @return this write builder for method chaining
*/
- WriteBuilder overwrite(Filter[] filters);
+ WriteBuilder overwrite(Predicate[] predicates);
@Override
default WriteBuilder truncate() {
- return overwrite(new Filter[] { AlwaysTrue$.MODULE$ });
+ return overwrite(new Predicate[] { new AlwaysTrue() });
}
}
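A consequence of the default truncate() above: any builder that can overwrite by filter can also truncate, because Spark expresses truncation as a single always-true predicate. A short hedged illustration, reusing the illustrative ExampleOverwriteBuilder sketched earlier:

    import org.apache.spark.sql.connector.expressions.filter.AlwaysTrue

    val builder = new ExampleOverwriteBuilder
    builder.truncate() // delegates to overwrite(Array(new AlwaysTrue))
    // canOverwrite accepts it, since AlwaysTrue's name is "ALWAYS_TRUE".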
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
index 1f8b416cf55..f139399ed76 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala
@@ -45,7 +45,7 @@ import org.apache.spark.unsafe.types.UTF8String
/**
* A simple in-memory table. Rows are stored as a buffered group produced by each output task.
*/
-class InMemoryBaseTable(
+abstract class InMemoryBaseTable(
val name: String,
val schema: StructType,
override val partitioning: Array[Transform],
@@ -337,59 +337,39 @@ class InMemoryBaseTable(
}
}
- override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
- InMemoryBaseTable.maybeSimulateFailedTableWrite(new CaseInsensitiveStringMap(properties))
- InMemoryBaseTable.maybeSimulateFailedTableWrite(info.options)
+ abstract class InMemoryWriterBuilder() extends SupportsTruncate with SupportsDynamicOverwrite
+ with SupportsStreamingUpdateAsAppend {
- new WriteBuilder with SupportsTruncate with SupportsOverwrite
- with SupportsDynamicOverwrite with SupportsStreamingUpdateAsAppend {
+ protected var writer: BatchWrite = Append
+ protected var streamingWriter: StreamingWrite = StreamingAppend
- private var writer: BatchWrite = Append
- private var streamingWriter: StreamingWrite = StreamingAppend
-
- override def truncate(): WriteBuilder = {
- assert(writer == Append)
- writer = TruncateAndAppend
- streamingWriter = StreamingTruncateAndAppend
- this
- }
-
- override def overwrite(filters: Array[Filter]): WriteBuilder = {
- assert(writer == Append)
- writer = new Overwrite(filters)
- streamingWriter = new StreamingNotSupportedOperation(
- s"overwrite (${filters.mkString("filters(", ", ", ")")})")
- this
- }
-
- override def overwriteDynamicPartitions(): WriteBuilder = {
- assert(writer == Append)
- writer = DynamicOverwrite
- streamingWriter = new StreamingNotSupportedOperation("overwriteDynamicPartitions")
- this
- }
+ override def overwriteDynamicPartitions(): WriteBuilder = {
+ assert(writer == Append)
+ writer = DynamicOverwrite
+ streamingWriter = new StreamingNotSupportedOperation("overwriteDynamicPartitions")
+ this
+ }
- override def build(): Write = new Write with RequiresDistributionAndOrdering {
- override def requiredDistribution: Distribution = distribution
+ override def build(): Write = new Write with RequiresDistributionAndOrdering {
+ override def requiredDistribution: Distribution = distribution
- override def distributionStrictlyRequired: Boolean = isDistributionStrictlyRequired
+ override def distributionStrictlyRequired: Boolean = isDistributionStrictlyRequired
- override def requiredOrdering: Array[SortOrder] = ordering
+ override def requiredOrdering: Array[SortOrder] = ordering
- override def requiredNumPartitions(): Int = {
- numPartitions.getOrElse(0)
- }
+ override def requiredNumPartitions(): Int = {
+ numPartitions.getOrElse(0)
+ }
- override def toBatch: BatchWrite = writer
+ override def toBatch: BatchWrite = writer
- override def toStreaming: StreamingWrite = streamingWriter match {
- case exc: StreamingNotSupportedOperation => exc.throwsException()
- case s => s
- }
+ override def toStreaming: StreamingWrite = streamingWriter match {
+ case exc: StreamingNotSupportedOperation => exc.throwsException()
+ case s => s
+ }
- override def supportedCustomMetrics(): Array[CustomMetric] = {
- Array(new InMemorySimpleCustomMetric)
- }
+ override def supportedCustomMetrics(): Array[CustomMetric] = {
+ Array(new InMemorySimpleCustomMetric)
}
}
}
@@ -402,7 +382,7 @@ class InMemoryBaseTable(
override def abort(messages: Array[WriterCommitMessage]): Unit = {}
}
- private object Append extends TestBatchWrite {
+ protected object Append extends TestBatchWrite {
override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
withData(messages.map(_.asInstanceOf[BufferedRows]))
}
@@ -416,24 +396,14 @@ class InMemoryBaseTable(
}
}
- private class Overwrite(filters: Array[Filter]) extends TestBatchWrite {
- import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
- override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
- val deleteKeys = InMemoryBaseTable.filtersToKeys(
- dataMap.keys, partCols.map(_.toSeq.quoted), filters)
- dataMap --= deleteKeys
- withData(messages.map(_.asInstanceOf[BufferedRows]))
- }
- }
-
- private object TruncateAndAppend extends TestBatchWrite {
+ protected object TruncateAndAppend extends TestBatchWrite {
override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
dataMap.clear
withData(messages.map(_.asInstanceOf[BufferedRows]))
}
}
- private abstract class TestStreamingWrite extends StreamingWrite {
+ protected abstract class TestStreamingWrite extends StreamingWrite {
def createStreamingWriterFactory(info: PhysicalWriteInfo): StreamingDataWriterFactory = {
BufferedRowsWriterFactory
}
@@ -441,7 +411,7 @@ class InMemoryBaseTable(
def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
}
- private class StreamingNotSupportedOperation(operation: String) extends TestStreamingWrite {
+ protected class StreamingNotSupportedOperation(operation: String) extends TestStreamingWrite {
override def createStreamingWriterFactory(info: PhysicalWriteInfo): StreamingDataWriterFactory =
throwsException()
@@ -463,7 +433,7 @@ class InMemoryBaseTable(
}
}
- private object StreamingTruncateAndAppend extends TestStreamingWrite {
+ protected object StreamingTruncateAndAppend extends TestStreamingWrite {
override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
dataMap.synchronized {
dataMap.clear
@@ -476,46 +446,7 @@ class InMemoryBaseTable(
object InMemoryBaseTable {
val SIMULATE_FAILED_WRITE_OPTION = "spark.sql.test.simulateFailedWrite"
- def filtersToKeys(
- keys: Iterable[Seq[Any]],
- partitionNames: Seq[String],
- filters: Array[Filter]): Iterable[Seq[Any]] = {
- keys.filter { partValues =>
- filters.flatMap(splitAnd).forall {
- case EqualTo(attr, value) =>
- value == extractValue(attr, partitionNames, partValues)
- case EqualNullSafe(attr, value) =>
- val attrVal = extractValue(attr, partitionNames, partValues)
- if (attrVal == null && value == null) {
- true
- } else if (attrVal == null || value == null) {
- false
- } else {
- value == attrVal
- }
- case IsNull(attr) =>
- null == extractValue(attr, partitionNames, partValues)
- case IsNotNull(attr) =>
- null != extractValue(attr, partitionNames, partValues)
- case AlwaysTrue() => true
- case f =>
- throw new IllegalArgumentException(s"Unsupported filter type: $f")
- }
- }
- }
-
- def supportsFilters(filters: Array[Filter]): Boolean = {
- filters.flatMap(splitAnd).forall {
- case _: EqualTo => true
- case _: EqualNullSafe => true
- case _: IsNull => true
- case _: IsNotNull => true
- case _: AlwaysTrue => true
- case _ => false
- }
- }
-
- private def extractValue(
+ def extractValue(
attr: String,
partFieldNames: Seq[String],
partValues: Seq[Any]): Any = {
@@ -527,13 +458,6 @@ object InMemoryBaseTable {
}
}
- private def splitAnd(filter: Filter): Seq[Filter] = {
- filter match {
- case And(left, right) => splitAnd(left) ++ splitAnd(right)
- case _ => filter :: Nil
- }
- }
-
def maybeSimulateFailedTableWrite(tableOptions: CaseInsensitiveStringMap): Unit = {
if (tableOptions.getBoolean(SIMULATE_FAILED_WRITE_OPTION, false)) {
throw new IllegalStateException("Manual write to table failure.")
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
index b82641a5d24..cd6821c8739 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
@@ -19,10 +19,14 @@ package org.apache.spark.sql.connector.catalog
import java.util
+import org.scalatest.Assertions.assert
+
import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
import org.apache.spark.sql.connector.expressions.{SortOrder, Transform}
-import org.apache.spark.sql.sources.Filter
+import org.apache.spark.sql.connector.write.{LogicalWriteInfo, SupportsOverwrite, WriteBuilder, WriterCommitMessage}
+import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
/**
* A simple in-memory table. Rows are stored as a buffered group produced by each output task.
@@ -40,12 +44,12 @@ class InMemoryTable(
ordering, numPartitions, isDistributionStrictlyRequired) with SupportsDelete {
override def canDeleteWhere(filters: Array[Filter]): Boolean = {
- InMemoryBaseTable.supportsFilters(filters)
+ InMemoryTable.supportsFilters(filters)
}
override def deleteWhere(filters: Array[Filter]): Unit = dataMap.synchronized {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
- dataMap --= InMemoryBaseTable.filtersToKeys(dataMap.keys, partCols.map(_.toSeq.quoted), filters)
+ dataMap --= InMemoryTable.filtersToKeys(dataMap.keys, partCols.map(_.toSeq.quoted), filters)
}
override def withData(data: Array[BufferedRows]): InMemoryTable = {
@@ -64,4 +68,93 @@ class InMemoryTable(
})
this
}
+
+ override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
+ InMemoryBaseTable.maybeSimulateFailedTableWrite(new CaseInsensitiveStringMap(properties))
+ InMemoryBaseTable.maybeSimulateFailedTableWrite(info.options)
+
+ new InMemoryWriterBuilderWithOverWrite()
+ }
+
+ private class InMemoryWriterBuilderWithOverWrite() extends InMemoryWriterBuilder
+ with SupportsOverwrite {
+
+ override def truncate(): WriteBuilder = {
+ assert(writer == Append)
+ writer = TruncateAndAppend
+ streamingWriter = StreamingTruncateAndAppend
+ this
+ }
+
+ override def overwrite(filters: Array[Filter]): WriteBuilder = {
+ assert(writer == Append)
+ writer = new Overwrite(filters)
+ streamingWriter = new StreamingNotSupportedOperation(
+ s"overwrite (${filters.mkString("filters(", ", ", ")")})")
+ this
+ }
+
+ override def canOverwrite(filters: Array[Filter]): Boolean = {
+ InMemoryTable.supportsFilters(filters)
+ }
+ }
+
+ private class Overwrite(filters: Array[Filter]) extends TestBatchWrite {
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
+ override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
+ val deleteKeys = InMemoryTable.filtersToKeys(
+ dataMap.keys, partCols.map(_.toSeq.quoted), filters)
+ dataMap --= deleteKeys
+ withData(messages.map(_.asInstanceOf[BufferedRows]))
+ }
+ }
+}
+
+object InMemoryTable {
+
+ def filtersToKeys(
+ keys: Iterable[Seq[Any]],
+ partitionNames: Seq[String],
+ filters: Array[Filter]): Iterable[Seq[Any]] = {
+ keys.filter { partValues =>
+ filters.flatMap(splitAnd).forall {
+ case EqualTo(attr, value) =>
+ value == InMemoryBaseTable.extractValue(attr, partitionNames, partValues)
+ case EqualNullSafe(attr, value) =>
+ val attrVal = InMemoryBaseTable.extractValue(attr, partitionNames, partValues)
+ if (attrVal == null && value == null) {
+ true
+ } else if (attrVal == null || value == null) {
+ false
+ } else {
+ value == attrVal
+ }
+ case IsNull(attr) =>
+ null == InMemoryBaseTable.extractValue(attr, partitionNames, partValues)
+ case IsNotNull(attr) =>
+ null != InMemoryBaseTable.extractValue(attr, partitionNames, partValues)
+ case AlwaysTrue() => true
+ case f =>
+ throw new IllegalArgumentException(s"Unsupported filter type: $f")
+ }
+ }
+ }
+
+ def supportsFilters(filters: Array[Filter]): Boolean = {
+ filters.flatMap(splitAnd).forall {
+ case _: EqualTo => true
+ case _: EqualNullSafe => true
+ case _: IsNull => true
+ case _: IsNotNull => true
+ case _: AlwaysTrue => true
+ case _ => false
+ }
+ }
+
+ private def splitAnd(filter: Filter): Seq[Filter] = {
+ filter match {
+ case And(left, right) => splitAnd(left) ++ splitAnd(right)
+ case _ => filter :: Nil
+ }
+ }
}
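The helpers above only understand conjunctions of a handful of partition-level filters; splitAnd flattens nested Ands before matching. A hedged illustration of the matching semantics, assuming these test classes are on the classpath and a table partitioned by (id, p):

    import org.apache.spark.sql.connector.catalog.InMemoryTable
    import org.apache.spark.sql.sources.{And, EqualTo, Filter, IsNotNull}

    val partitionKeys = Seq(Seq[Any](2L, 3), Seq[Any](3L, 3))
    val filters: Array[Filter] = Array(And(EqualTo("id", 2L), IsNotNull("p")))

    // splitAnd yields (EqualTo, IsNotNull); only the key with id = 2
    // satisfies both conjuncts, so only it is selected for deletion.
    val matched = InMemoryTable.filtersToKeys(partitionKeys, Seq("id", "p"), filters)
    assert(matched == Seq(Seq(2L, 3)))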
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala
index 48000dd0d98..b4285f31dd7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableWithV2Filter.scala
@@ -19,9 +19,12 @@ package org.apache.spark.sql.connector.catalog
import java.util
+import org.scalatest.Assertions.assert
+
import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue, NamedReference, Transform}
import org.apache.spark.sql.connector.expressions.filter.{And, Predicate}
import org.apache.spark.sql.connector.read.{InputPartition, Scan, ScanBuilder, SupportsRuntimeV2Filtering}
+import org.apache.spark.sql.connector.write.{LogicalWriteInfo, SupportsOverwriteV2, WriteBuilder, WriterCommitMessage}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -32,8 +35,8 @@ class InMemoryTableWithV2Filter(
properties: util.Map[String, String])
extends InMemoryBaseTable(name, schema, partitioning, properties) with SupportsDeleteV2 {
- override def canDeleteWhere(filters: Array[Predicate]): Boolean = {
- InMemoryTableWithV2Filter.supportsFilters(filters)
+ override def canDeleteWhere(predicates: Array[Predicate]): Boolean = {
+ InMemoryTableWithV2Filter.supportsPredicates(predicates)
}
override def deleteWhere(filters: Array[Predicate]): Unit = dataMap.synchronized {
@@ -84,6 +87,46 @@ class InMemoryTableWithV2Filter(
}
}
}
+
+ override def newWriteBuilder(info: LogicalWriteInfo): WriteBuilder = {
+ InMemoryBaseTable.maybeSimulateFailedTableWrite(new CaseInsensitiveStringMap(properties))
+ InMemoryBaseTable.maybeSimulateFailedTableWrite(info.options)
+
+ new InMemoryWriterBuilderWithOverWrite()
+ }
+
+ private class InMemoryWriterBuilderWithOverWrite() extends InMemoryWriterBuilder
+ with SupportsOverwriteV2 {
+
+ override def truncate(): WriteBuilder = {
+ assert(writer == Append)
+ writer = TruncateAndAppend
+ streamingWriter = StreamingTruncateAndAppend
+ this
+ }
+
+ override def overwrite(predicates: Array[Predicate]): WriteBuilder = {
+ assert(writer == Append)
+ writer = new Overwrite(predicates)
+ streamingWriter = new StreamingNotSupportedOperation(
+ s"overwrite (${predicates.mkString("filters(", ", ", ")")})")
+ this
+ }
+
+ override def canOverwrite(predicates: Array[Predicate]): Boolean = {
+ InMemoryTableWithV2Filter.supportsPredicates(predicates)
+ }
+ }
+
+ private class Overwrite(predicates: Array[Predicate]) extends TestBatchWrite {
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
+ override def commit(messages: Array[WriterCommitMessage]): Unit = dataMap.synchronized {
+ val deleteKeys = InMemoryTableWithV2Filter.filtersToKeys(
+ dataMap.keys, partCols.map(_.toSeq.quoted), predicates)
+ dataMap --= deleteKeys
+ withData(messages.map(_.asInstanceOf[BufferedRows]))
+ }
+ }
}
object InMemoryTableWithV2Filter {
@@ -96,9 +139,10 @@ object InMemoryTableWithV2Filter {
filters.flatMap(splitAnd).forall {
case p: Predicate if p.name().equals("=") =>
p.children()(1).asInstanceOf[LiteralValue[_]].value ==
- extractValue(p.children()(0).toString, partitionNames, partValues)
+ InMemoryBaseTable.extractValue(p.children()(0).toString, partitionNames, partValues)
case p: Predicate if p.name().equals("<=>") =>
- val attrVal = extractValue(p.children()(0).toString, partitionNames, partValues)
+ val attrVal = InMemoryBaseTable
+ .extractValue(p.children()(0).toString, partitionNames, partValues)
val value = p.children()(1).asInstanceOf[LiteralValue[_]].value
if (attrVal == null && value == null) {
true
@@ -109,10 +153,10 @@ object InMemoryTableWithV2Filter {
}
case p: Predicate if p.name().equals("IS NULL") =>
val attr = p.children()(0).toString
- null == extractValue(attr, partitionNames, partValues)
+ null == InMemoryBaseTable.extractValue(attr, partitionNames, partValues)
case p: Predicate if p.name().equals("IS NOT NULL") =>
val attr = p.children()(0).toString
- null != extractValue(attr, partitionNames, partValues)
+ null != InMemoryBaseTable.extractValue(attr, partitionNames, partValues)
case p: Predicate if p.name().equals("ALWAYS_TRUE") => true
case f =>
throw new IllegalArgumentException(s"Unsupported filter type: $f")
@@ -120,8 +164,8 @@ object InMemoryTableWithV2Filter {
}
}
- def supportsFilters(filters: Array[Predicate]): Boolean = {
- filters.flatMap(splitAnd).forall {
+ def supportsPredicates(predicates: Array[Predicate]): Boolean = {
+ predicates.flatMap(splitAnd).forall {
case p: Predicate if p.name().equals("=") => true
case p: Predicate if p.name().equals("<=>") => true
case p: Predicate if p.name().equals("IS NULL") => true
@@ -131,18 +175,6 @@ object InMemoryTableWithV2Filter {
}
}
- private def extractValue(
- attr: String,
- partFieldNames: Seq[String],
- partValues: Seq[Any]): Any = {
- partFieldNames.zipWithIndex.find(_._1 == attr) match {
- case Some((_, partIndex)) =>
- partValues(partIndex)
- case _ =>
- throw new IllegalArgumentException(s"Unknown filter attribute: $attr")
- }
- }
-
private def splitAnd(filter: Predicate): Seq[Predicate] = {
filter match {
case and: And => splitAnd(and.left()) ++ splitAnd(and.right())
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala
index 4422743c5ac..2d47d94ff1d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2Writes.scala
@@ -24,12 +24,11 @@ import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, Ove
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table}
-import org.apache.spark.sql.connector.write.{LogicalWriteInfoImpl, SupportsDynamicOverwrite, SupportsOverwrite, SupportsTruncate, Write, WriteBuilder}
+import org.apache.spark.sql.connector.expressions.filter.Predicate
+import org.apache.spark.sql.connector.write.{LogicalWriteInfoImpl, SupportsDynamicOverwrite, SupportsOverwriteV2, SupportsTruncate, Write, WriteBuilder}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
-import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.execution.streaming.sources.{MicroBatchWrite, WriteToMicroBatchDataSource}
import org.apache.spark.sql.internal.connector.SupportsStreamingUpdateAsAppend
-import org.apache.spark.sql.sources.{AlwaysTrue, Filter}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
@@ -49,21 +48,21 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
case o @ OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, options, _, None) =>
// fail if any filter cannot be converted. correctness depends on removing all matching data.
- val filters = splitConjunctivePredicates(deleteExpr).flatMap { pred =>
- val filter = DataSourceStrategy.translateFilter(pred, supportNestedPredicatePushdown = true)
- if (filter.isEmpty) {
+ val predicates = splitConjunctivePredicates(deleteExpr).flatMap { pred =>
+ val predicate = DataSourceV2Strategy.translateFilterV2(pred)
+ if (predicate.isEmpty) {
throw QueryCompilationErrors.cannotTranslateExpressionToSourceFilterError(pred)
}
- filter
+ predicate
}.toArray
val table = r.table
val writeBuilder = newWriteBuilder(table, options, query.schema)
val write = writeBuilder match {
- case builder: SupportsTruncate if isTruncate(filters) =>
+ case builder: SupportsTruncate if isTruncate(predicates) =>
builder.truncate().build()
- case builder: SupportsOverwrite =>
- builder.overwrite(filters).build()
+ case builder: SupportsOverwriteV2 if builder.canOverwrite(predicates) =>
+ builder.overwrite(predicates).build()
case _ =>
throw QueryExecutionErrors.overwriteTableByUnsupportedExpressionError(table)
}
@@ -123,8 +122,8 @@ object V2Writes extends Rule[LogicalPlan] with PredicateHelper {
}
}
- private def isTruncate(filters: Array[Filter]): Boolean = {
- filters.length == 1 && filters(0).isInstanceOf[AlwaysTrue]
+ private def isTruncate(predicates: Array[Predicate]): Boolean = {
+ predicates.length == 1 && predicates(0).name().equals("ALWAYS_TRUE")
}
private def newWriteBuilder(
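Restated outside the rule for clarity: an overwrite with a lone always-true predicate is planned as a truncate, and a filtered overwrite is only planned when the builder accepts every predicate; otherwise the rule raises the unsupported-expression error above. A hedged standalone restatement of the check:

    import org.apache.spark.sql.connector.expressions.filter.{AlwaysTrue, Predicate}

    // Same logic as V2Writes.isTruncate, lifted out of the rule.
    def isTruncate(predicates: Array[Predicate]): Boolean =
      predicates.length == 1 && predicates(0).name() == "ALWAYS_TRUE"

    assert(isTruncate(Array[Predicate](new AlwaysTrue())))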
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index 9ec5be46fc2..629a5ac83c8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -40,18 +40,13 @@ import org.apache.spark.sql.sources.SimpleScanSource
import org.apache.spark.sql.types.{LongType, MetadataBuilder, StringType, StructField, StructType}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.unsafe.types.UTF8String
-import org.apache.spark.util.Utils
-class DataSourceV2SQLSuite
+abstract class DataSourceV2SQLSuite
extends InsertIntoTests(supportsDynamicOverwrite = true, includeSQLOnlyTests = true)
- with AlterTableTests with DatasourceV2SQLBase {
+ with DeleteFromTests with DatasourceV2SQLBase {
- import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
-
- private val v2Source = classOf[FakeV2Provider].getName
+ protected val v2Source = classOf[FakeV2Provider].getName
override protected val v2Format = v2Source
- override protected val catalogAndNamespace = "testcat.ns1.ns2."
- private val defaultUser: String = Utils.getCurrentUserName()
protected def doInsert(tableName: String, insert: DataFrame, mode: SaveMode): Unit = {
val tmpView = "tmp_view"
@@ -66,6 +61,20 @@ class DataSourceV2SQLSuite
checkAnswer(spark.table(tableName), expected)
}
+ protected def assertAnalysisError(
+ sqlStatement: String,
+ expectedError: String): Unit = {
+ val ex = intercept[AnalysisException] {
+ sql(sqlStatement)
+ }
+ assert(ex.getMessage.contains(expectedError))
+ }
+}
+
+class DataSourceV2SQLSuiteV1Filter extends DataSourceV2SQLSuite with AlterTableTests {
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+
+ override protected val catalogAndNamespace = "testcat.ns1.ns2."
override def getTableMetadata(tableName: String): Table = {
val nameParts = spark.sessionState.sqlParser.parseMultipartIdentifier(tableName)
val v2Catalog = catalog(nameParts.head).asTableCatalog
@@ -622,8 +631,8 @@ class DataSourceV2SQLSuite
assert(table.partitioning.isEmpty)
assert(table.properties == withDefaultOwnership(Map("provider" -> v2Source)).asJava)
assert(table.schema == new StructType()
- .add("id", LongType)
- .add("data", StringType))
+ .add("id", LongType)
+ .add("data", StringType))
val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), spark.table("source"))
@@ -639,8 +648,8 @@ class DataSourceV2SQLSuite
assert(table.partitioning.isEmpty)
assert(table.properties == withDefaultOwnership(Map("provider" -> "foo")).asJava)
assert(table.schema == new StructType()
- .add("id", LongType)
- .add("data", StringType))
+ .add("id", LongType)
+ .add("data", StringType))
val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), spark.table("source"))
@@ -659,8 +668,8 @@ class DataSourceV2SQLSuite
assert(table2.partitioning.isEmpty)
assert(table2.properties == withDefaultOwnership(Map("provider" -> "foo")).asJava)
assert(table2.schema == new StructType()
- .add("id", LongType)
- .add("data", StringType))
+ .add("id", LongType)
+ .add("data", StringType))
val rdd2 = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
checkAnswer(spark.internalCreateDataFrame(rdd2, table.schema), spark.table("source"))
@@ -677,8 +686,8 @@ class DataSourceV2SQLSuite
assert(table.partitioning.isEmpty)
assert(table.properties == withDefaultOwnership(Map("provider" -> "foo")).asJava)
assert(table.schema == new StructType()
- .add("id", LongType)
- .add("data", StringType))
+ .add("id", LongType)
+ .add("data", StringType))
val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), spark.table("source"))
@@ -708,8 +717,8 @@ class DataSourceV2SQLSuite
assert(table.partitioning.isEmpty)
assert(table.properties == withDefaultOwnership(Map("provider" -> "foo")).asJava)
assert(table.schema == new StructType()
- .add("id", LongType)
- .add("data", StringType))
+ .add("id", LongType)
+ .add("data", StringType))
val rdd = sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows)
checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), spark.table("source"))
@@ -1526,148 +1535,6 @@ class DataSourceV2SQLSuite
assert(e.message.contains("REPLACE TABLE is only supported with v2 tables"))
}
- test("DeleteFrom: basic - delete all") {
- val t = "testcat.ns1.ns2.tbl"
- withTable(t) {
- sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)")
- sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
- sql(s"DELETE FROM $t")
- checkAnswer(spark.table(t), Seq())
- }
- }
-
- test("DeleteFrom with v2 filtering: basic - delete all") {
- val t = "testv2filter.ns1.ns2.tbl"
- withTable(t) {
- sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)")
- sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
- sql(s"DELETE FROM $t")
- checkAnswer(spark.table(t), Seq())
- }
- }
-
- test("DeleteFrom: basic - delete with where clause") {
- val t = "testcat.ns1.ns2.tbl"
- withTable(t) {
- sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)")
- sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
- sql(s"DELETE FROM $t WHERE id = 2")
- checkAnswer(spark.table(t), Seq(
- Row(3, "c", 3)))
- }
- }
-
- test("DeleteFrom with v2 filtering: basic - delete with where clause") {
- val t = "testv2filter.ns1.ns2.tbl"
- withTable(t) {
- sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)")
- sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
- sql(s"DELETE FROM $t WHERE id = 2")
- checkAnswer(spark.table(t), Seq(
- Row(3, "c", 3)))
- }
- }
-
- test("DeleteFrom: delete from aliased target table") {
- val t = "testcat.ns1.ns2.tbl"
- withTable(t) {
- sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)")
- sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
- sql(s"DELETE FROM $t AS tbl WHERE tbl.id = 2")
- checkAnswer(spark.table(t), Seq(
- Row(3, "c", 3)))
- }
- }
-
- test("DeleteFrom with v2 filtering: delete from aliased target table") {
- val t = "testv2filter.ns1.ns2.tbl"
- withTable(t) {
- sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)")
- sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
- sql(s"DELETE FROM $t AS tbl WHERE tbl.id = 2")
- checkAnswer(spark.table(t), Seq(
- Row(3, "c", 3)))
- }
- }
-
- test("DeleteFrom: normalize attribute names") {
- val t = "testcat.ns1.ns2.tbl"
- withTable(t) {
- sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)")
- sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
- sql(s"DELETE FROM $t AS tbl WHERE tbl.ID = 2")
- checkAnswer(spark.table(t), Seq(
- Row(3, "c", 3)))
- }
- }
-
- test("DeleteFrom with v2 filtering: normalize attribute names") {
- val t = "testv2filter.ns1.ns2.tbl"
- withTable(t) {
- sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)")
- sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
- sql(s"DELETE FROM $t AS tbl WHERE tbl.ID = 2")
- checkAnswer(spark.table(t), Seq(
- Row(3, "c", 3)))
- }
- }
-
- test("DeleteFrom: fail if has subquery") {
- val t = "testcat.ns1.ns2.tbl"
- withTable(t) {
- sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)")
- sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
- val exc = intercept[AnalysisException] {
- sql(s"DELETE FROM $t WHERE id IN (SELECT id FROM $t)")
- }
-
- assert(spark.table(t).count === 3)
- assert(exc.getMessage.contains("Delete by condition with subquery is not supported"))
- }
- }
-
- test("DeleteFrom with v2 filtering: fail if has subquery") {
- val t = "testv2filter.ns1.ns2.tbl"
- withTable(t) {
- sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)")
- sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
- val exc = intercept[AnalysisException] {
- sql(s"DELETE FROM $t WHERE id IN (SELECT id FROM $t)")
- }
-
- assert(spark.table(t).count === 3)
- assert(exc.getMessage.contains("Delete by condition with subquery is not supported"))
- }
- }
-
- test("DeleteFrom: delete with unsupported predicates") {
- val t = "testcat.ns1.ns2.tbl"
- withTable(t) {
- sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo")
- sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
- val exc = intercept[AnalysisException] {
- sql(s"DELETE FROM $t WHERE id > 3 AND p > 3")
- }
-
- assert(spark.table(t).count === 3)
- assert(exc.getMessage.contains(s"Cannot delete from table $t"))
- }
- }
-
- test("DeleteFrom with v2 filtering: delete with unsupported predicates") {
- val t = "testv2filter.ns1.ns2.tbl"
- withTable(t) {
- sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo")
- sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
- val exc = intercept[AnalysisException] {
- sql(s"DELETE FROM $t WHERE id > 3 AND p > 3")
- }
-
- assert(spark.table(t).count === 3)
- assert(exc.getMessage.contains(s"Cannot delete from table $t"))
- }
- }
-
test("DeleteFrom: - delete with invalid predicate") {
val t = "testcat.ns1.ns2.tbl"
withTable(t) {
@@ -1682,37 +1549,6 @@ class DataSourceV2SQLSuite
}
}
- test("DeleteFrom: DELETE is only supported with v2 tables") {
- // unset this config to use the default v2 session catalog.
- spark.conf.unset(V2_SESSION_CATALOG_IMPLEMENTATION.key)
- val v1Table = "tbl"
- withTable(v1Table) {
- sql(s"CREATE TABLE $v1Table" +
- s" USING ${classOf[SimpleScanSource].getName} OPTIONS (from=0,to=1)")
- val exc = intercept[AnalysisException] {
- sql(s"DELETE FROM $v1Table WHERE i = 2")
- }
-
- assert(exc.getMessage.contains("DELETE is only supported with v2 tables"))
- }
- }
-
- test("SPARK-33652: DeleteFrom should refresh caches referencing the table") {
- val t = "testcat.ns1.ns2.tbl"
- val view = "view"
- withTable(t) {
- withTempView(view) {
- sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)")
- sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
- sql(s"CACHE TABLE view AS SELECT id FROM $t")
- assert(spark.table(view).count() == 3)
-
- sql(s"DELETE FROM $t WHERE id = 2")
- assert(spark.table(view).count() == 1)
- }
- }
- }
-
test("UPDATE TABLE") {
val t = "testcat.ns1.ns2.tbl"
withTable(t) {
@@ -2272,7 +2108,7 @@ class DataSourceV2SQLSuite
val t1 = s"${catalogAndNamespace}table"
withTable(t1) {
sql(s"CREATE TABLE $t1 (id bigint, data string) USING $v2Format " +
- "PARTITIONED BY (bucket(4, id), id)")
+ "PARTITIONED BY (bucket(4, id), id)")
val sqlQuery = spark.sql(s"SELECT * FROM $t1 WHERE index = 0")
val dfQuery = spark.table(t1).filter("index = 0")
@@ -2796,15 +2632,6 @@ class DataSourceV2SQLSuite
assert(e.message.contains(s"$sqlCommand is not supported for v2 tables"))
}
- private def assertAnalysisError(
- sqlStatement: String,
- expectedError: String): Unit = {
- val ex = intercept[AnalysisException] {
- sql(sqlStatement)
- }
- assert(ex.getMessage.contains(expectedError))
- }
-
private def assertAnalysisErrorClass(
sqlStatement: String,
expectedErrorClass: String,
@@ -2815,8 +2642,12 @@ class DataSourceV2SQLSuite
assert(ex.getErrorClass == expectedErrorClass)
assert(ex.messageParameters.sameElements(expectedErrorMessageParameters))
}
+
}
+class DataSourceV2SQLSuiteV2Filter extends DataSourceV2SQLSuite {
+ override protected val catalogAndNamespace = "testv2filter.ns1.ns2."
+}
/** Used as a V2 DataSource for V2SessionCatalog DDL */
class FakeV2Provider extends SimpleTableProvider {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTests.scala
new file mode 100644
index 00000000000..5ed64df6280
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DeleteFromTests.scala
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector
+
+import org.apache.spark.sql._
+import org.apache.spark.sql.internal.SQLConf.V2_SESSION_CATALOG_IMPLEMENTATION
+import org.apache.spark.sql.sources.SimpleScanSource
+
+/**
+ * A collection of "DELETE" tests that can be run through the SQL APIs.
+ */
+trait DeleteFromTests extends DatasourceV2SQLBase {
+
+ protected val catalogAndNamespace: String
+
+ test("DeleteFrom with v2 filtering: basic - delete all") {
+ val t = s"${catalogAndNamespace}tbl"
+ withTable(t) {
+ sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)")
+ sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
+ sql(s"DELETE FROM $t")
+ checkAnswer(spark.table(t), Seq())
+ }
+ }
+
+ test("DeleteFrom with v2 filtering: basic - delete with where clause") {
+ val t = s"${catalogAndNamespace}tbl"
+ withTable(t) {
+ sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)")
+ sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
+ sql(s"DELETE FROM $t WHERE id = 2")
+ checkAnswer(spark.table(t), Seq(
+ Row(3, "c", 3)))
+ }
+ }
+
+ test("DeleteFrom with v2 filtering: delete from aliased target table") {
+ val t = s"${catalogAndNamespace}tbl"
+ withTable(t) {
+ sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)")
+ sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
+ sql(s"DELETE FROM $t AS tbl WHERE tbl.id = 2")
+ checkAnswer(spark.table(t), Seq(
+ Row(3, "c", 3)))
+ }
+ }
+
+ test("DeleteFrom with v2 filtering: normalize attribute names") {
+ val t = s"${catalogAndNamespace}tbl"
+ withTable(t) {
+ sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)")
+ sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
+ sql(s"DELETE FROM $t AS tbl WHERE tbl.ID = 2")
+ checkAnswer(spark.table(t), Seq(
+ Row(3, "c", 3)))
+ }
+ }
+
+ test("DeleteFrom with v2 filtering: fail if has subquery") {
+ val t = s"${catalogAndNamespace}tbl"
+ withTable(t) {
+ sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)")
+ sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
+ val exc = intercept[AnalysisException] {
+ sql(s"DELETE FROM $t WHERE id IN (SELECT id FROM $t)")
+ }
+
+ assert(spark.table(t).count === 3)
+ assert(exc.getMessage.contains("Delete by condition with subquery is not supported"))
+ }
+ }
+
+ test("DeleteFrom with v2 filtering: delete with unsupported predicates") {
+ val t = s"${catalogAndNamespace}tbl"
+ withTable(t) {
+ sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo")
+ sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
+ val exc = intercept[AnalysisException] {
+ sql(s"DELETE FROM $t WHERE id > 3 AND p > 3")
+ }
+
+ assert(spark.table(t).count === 3)
+ assert(exc.getMessage.contains(s"Cannot delete from table $t"))
+ }
+ }
+
+ test("DeleteFrom: DELETE is only supported with v2 tables") {
+ // unset this config to use the default v2 session catalog.
+ spark.conf.unset(V2_SESSION_CATALOG_IMPLEMENTATION.key)
+ val v1Table = "tbl"
+ withTable(v1Table) {
+ sql(s"CREATE TABLE $v1Table" +
+ s" USING ${classOf[SimpleScanSource].getName} OPTIONS (from=0,to=1)")
+ val exc = intercept[AnalysisException] {
+ sql(s"DELETE FROM $v1Table WHERE i = 2")
+ }
+
+ assert(exc.getMessage.contains("DELETE is only supported with v2 tables"))
+ }
+ }
+
+ test("SPARK-33652: DeleteFrom should refresh caches referencing the table") {
+ val t = s"${catalogAndNamespace}tbl"
+ val view = "view"
+ withTable(t) {
+ withTempView(view) {
+ sql(s"CREATE TABLE $t (id bigint, data string, p int) USING foo PARTITIONED BY (id, p)")
+ sql(s"INSERT INTO $t VALUES (2L, 'a', 2), (2L, 'b', 3), (3L, 'c', 3)")
+ sql(s"CACHE TABLE view AS SELECT id FROM $t")
+ assert(spark.table(view).count() == 3)
+
+ sql(s"DELETE FROM $t WHERE id = 2")
+ assert(spark.table(view).count() == 1)
+ }
+ }
+ }
+}
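The two concrete suites shown earlier reuse this trait by supplying only a catalog prefix. A hedged sketch of a further suite against a hypothetical catalog name:

    // Hypothetical: a suite that provides a registered catalog prefix
    // inherits every DELETE test above against that catalog.
    class MyCatalogDeleteFromSuite extends DeleteFromTests {
      override protected val catalogAndNamespace = "mycat.ns1.ns2."
    }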
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
index fe4f70e57ef..992c46cc6cd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row, SaveM
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
-import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryBaseTable, SupportsRead, SupportsWrite, Table, TableCapability}
+import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable, SupportsRead, SupportsWrite, Table, TableCapability}
import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform, Transform}
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, V1Scan}
import org.apache.spark.sql.connector.write.{LogicalWriteInfo, LogicalWriteInfoImpl, SupportsOverwrite, SupportsTruncate, V1Write, WriteBuilder}
@@ -359,7 +359,7 @@ class InMemoryTableWithV1Fallback(
}
override def overwrite(filters: Array[Filter]): WriteBuilder = {
- val keys = InMemoryBaseTable.filtersToKeys(dataMap.keys, partFieldNames, filters)
+ val keys = InMemoryTable.filtersToKeys(dataMap.keys, partFieldNames, filters)
dataMap --= keys
mode = "overwrite"
this