Posted to commits@iceberg.apache.org by ao...@apache.org on 2021/02/04 04:11:41 UTC

[iceberg] branch master updated: Spark: Support UPDATE statements with subqueries (#2206)

This is an automated email from the ASF dual-hosted git repository.

aokolnychyi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/master by this push:
     new 4b6fa61  Spark: Support UPDATE statements with subqueries (#2206)
4b6fa61 is described below

commit 4b6fa61d268efd80c584a4445c29e3236a67dcbe
Author: Anton Okolnychyi <ao...@apple.com>
AuthorDate: Wed Feb 3 20:11:29 2021 -0800

    Spark: Support UPDATE statements with subqueries (#2206)
---
 .../extensions/IcebergSparkSessionExtensions.scala |   6 +-
 .../analysis/DeleteFromTablePredicateCheck.scala   |  53 ----
 ...cala => RowLevelOperationsPredicateCheck.scala} |  31 ++-
 ...pCorrelatedPredicatesInRowLevelOperations.scala |  22 +-
 .../sql/catalyst/optimizer/RewriteUpdate.scala     |  47 +++-
 .../SparkRowLevelOperationsTestBase.java           |   5 +
 .../iceberg/spark/extensions/TestDelete.java       |   7 +-
 .../iceberg/spark/extensions/TestUpdate.java       | 279 +++++++++++++++++++++
 .../iceberg/spark/source/SparkMergeScan.java       |   4 +-
 9 files changed, 376 insertions(+), 78 deletions(-)
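
For context, this change teaches the Iceberg Spark 3 extensions to rewrite UPDATE statements whose
conditions contain IN, EXISTS, NOT EXISTS, or scalar subqueries instead of rejecting them during
analysis. A minimal usage sketch, assuming a Spark 3 session with the Iceberg runtime on the
classpath and an already configured Iceberg catalog (the table and view names are placeholders,
not part of the commit):

    import org.apache.spark.sql.SparkSession

    // Sketch only: `db.employees` and `db.updated_id` are illustrative names.
    val spark = SparkSession.builder()
      .config("spark.sql.extensions",
        "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions")
      .getOrCreate()

    // Accepted after this commit: an UPDATE whose condition uses an IN subquery.
    spark.sql(
      """UPDATE db.employees
        |SET dep = 'hardware'
        |WHERE id IN (SELECT id FROM db.updated_id)""".stripMargin)

    // Still rejected during analysis (see RowLevelOperationsPredicateCheck below):
    // UPDATE db.employees SET dep = 'hr' WHERE id NOT IN (SELECT id FROM db.updated_id)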

diff --git a/spark3-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala b/spark3-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala
index d9f8f49..78b17d3 100644
--- a/spark3-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala
+++ b/spark3-extensions/src/main/scala/org/apache/iceberg/spark/extensions/IcebergSparkSessionExtensions.scala
@@ -21,10 +21,9 @@ package org.apache.iceberg.spark.extensions
 
 import org.apache.spark.sql.SparkSessionExtensions
 import org.apache.spark.sql.catalyst.analysis.AlignRowLevelOperations
-import org.apache.spark.sql.catalyst.analysis.DeleteFromTablePredicateCheck
-import org.apache.spark.sql.catalyst.analysis.MergeIntoTablePredicateCheck
 import org.apache.spark.sql.catalyst.analysis.ProcedureArgumentCoercion
 import org.apache.spark.sql.catalyst.analysis.ResolveProcedures
+import org.apache.spark.sql.catalyst.analysis.RowLevelOperationsPredicateCheck
 import org.apache.spark.sql.catalyst.optimizer.OptimizeConditionsInRowLevelOperations
 import org.apache.spark.sql.catalyst.optimizer.PullupCorrelatedPredicatesInRowLevelOperations
 import org.apache.spark.sql.catalyst.optimizer.RewriteDelete
@@ -43,8 +42,7 @@ class IcebergSparkSessionExtensions extends (SparkSessionExtensions => Unit) {
     extensions.injectResolutionRule { spark => ResolveProcedures(spark) }
     extensions.injectResolutionRule { _ => ProcedureArgumentCoercion }
     extensions.injectPostHocResolutionRule { spark => AlignRowLevelOperations(spark.sessionState.conf)}
-    extensions.injectCheckRule { _ => DeleteFromTablePredicateCheck }
-    extensions.injectCheckRule { _ => MergeIntoTablePredicateCheck }
+    extensions.injectCheckRule { _ => RowLevelOperationsPredicateCheck }
 
     // optimizer extensions
     extensions.injectOptimizerRule { _ => OptimizeConditionsInRowLevelOperations }
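
The consolidation above works because a Spark check rule is just a LogicalPlan => Unit function, so
the two previous predicate checks could be folded into a single object. A stripped-down sketch of
the injection pattern (ExampleExtensions and its no-op rule are illustrative, not part of this
commit):

    import org.apache.spark.sql.SparkSessionExtensions
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

    // A check rule is a plain function from LogicalPlan to Unit, registered through
    // injectCheckRule, just like RowLevelOperationsPredicateCheck in the diff below.
    class ExampleExtensions extends (SparkSessionExtensions => Unit) {
      override def apply(extensions: SparkSessionExtensions): Unit = {
        extensions.injectCheckRule { _ => (plan: LogicalPlan) => () } // validate the plan here
      }
    }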
diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeleteFromTablePredicateCheck.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeleteFromTablePredicateCheck.scala
deleted file mode 100644
index 68fca88..0000000
--- a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeleteFromTablePredicateCheck.scala
+++ /dev/null
@@ -1,53 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.spark.sql.catalyst.analysis
-
-import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.expressions.InSubquery
-import org.apache.spark.sql.catalyst.expressions.Not
-import org.apache.spark.sql.catalyst.plans.logical.DeleteFromTable
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.utils.PlanUtils.isIcebergRelation
-
-object DeleteFromTablePredicateCheck extends (LogicalPlan => Unit) {
-
-  override def apply(plan: LogicalPlan): Unit = {
-    plan foreach {
-      case DeleteFromTable(r, Some(condition)) if hasNullAwarePredicateWithinNot(condition) && isIcebergRelation(r) =>
-        // this limitation is present since SPARK-25154 fix is not yet available
-        // we use Not(EqualsNullSafe(cond, true)) when deciding which records to keep
-        // such conditions are rewritten by Spark as an existential join and currently Spark
-        // does not handle correctly NOT IN subqueries nested into other expressions
-        failAnalysis("Null-aware predicate sub-queries are not currently supported in DELETE")
-
-      case _ => // OK
-    }
-  }
-
-  private def hasNullAwarePredicateWithinNot(cond: Expression): Boolean = {
-    cond.find {
-      case Not(expr) if expr.find(_.isInstanceOf[InSubquery]).isDefined => true
-      case _ => false
-    }.isDefined
-  }
-
-  private def failAnalysis(msg: String): Unit = throw new AnalysisException(msg)
-}
diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/MergeIntoTablePredicateCheck.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RowLevelOperationsPredicateCheck.scala
similarity index 58%
rename from spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/MergeIntoTablePredicateCheck.scala
rename to spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RowLevelOperationsPredicateCheck.scala
index 179b4aa..0beab49 100644
--- a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/MergeIntoTablePredicateCheck.scala
+++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RowLevelOperationsPredicateCheck.scala
@@ -21,20 +21,40 @@ package org.apache.spark.sql.catalyst.analysis
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.InSubquery
+import org.apache.spark.sql.catalyst.expressions.Not
 import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
 import org.apache.spark.sql.catalyst.plans.logical.DeleteAction
+import org.apache.spark.sql.catalyst.plans.logical.DeleteFromTable
 import org.apache.spark.sql.catalyst.plans.logical.InsertAction
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.plans.logical.MergeIntoTable
 import org.apache.spark.sql.catalyst.plans.logical.UpdateAction
+import org.apache.spark.sql.catalyst.plans.logical.UpdateTable
 import org.apache.spark.sql.catalyst.utils.PlanUtils.isIcebergRelation
 
-object MergeIntoTablePredicateCheck extends (LogicalPlan => Unit) {
+object RowLevelOperationsPredicateCheck extends (LogicalPlan => Unit) {
 
   override def apply(plan: LogicalPlan): Unit = {
     plan foreach {
+      case DeleteFromTable(r, Some(condition)) if hasNullAwarePredicateWithinNot(condition) && isIcebergRelation(r) =>
+        // this limitation exists because the SPARK-25154 fix is not yet available
+        // we use Not(EqualNullSafe(cond, true)) when deciding which records to keep
+        // such conditions are rewritten by Spark as an existential join and Spark currently
+        // does not handle NOT IN subqueries nested inside other expressions correctly
+        failAnalysis("Null-aware predicate subqueries are not currently supported in DELETE")
+
+      case UpdateTable(r, _, Some(condition)) if hasNullAwarePredicateWithinNot(condition) && isIcebergRelation(r) =>
+        // this limitation exists because the SPARK-25154 fix is not yet available
+        // we use Not(EqualNullSafe(cond, true)) when processing records that did not match
+        // the update condition but were present in the files we are overwriting
+        // such conditions are rewritten by Spark as an existential join and Spark currently
+        // does not handle NOT IN subqueries nested inside other expressions correctly
+        failAnalysis("Null-aware predicate subqueries are not currently supported in UPDATE")
+
       case merge: MergeIntoTable if isIcebergRelation(merge.targetTable) =>
         validateMergeIntoConditions(merge)
+
       case _ => // OK
     }
   }
@@ -58,4 +78,13 @@ object MergeIntoTablePredicateCheck extends (LogicalPlan => Unit) {
         s"Found a subquery in the $condName condition: ${cond.sql}")
     }
   }
+
+  private def hasNullAwarePredicateWithinNot(cond: Expression): Boolean = {
+    cond.find {
+      case Not(expr) if expr.find(_.isInstanceOf[InSubquery]).isDefined => true
+      case _ => false
+    }.isDefined
+  }
+
+  private def failAnalysis(msg: String): Unit = throw new AnalysisException(msg)
 }
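
In practice, the check above means NOT IN predicate subqueries are rejected in both DELETE and
UPDATE until the SPARK-25154 fix is available, while NOT EXISTS remains a workable alternative
(the two are not equivalent when NULLs are involved, so rewrite with care). A sketch with
placeholder table names, assuming an active SparkSession named spark:

    // Rejected by RowLevelOperationsPredicateCheck: Not(...) wrapping an InSubquery.
    // spark.sql("UPDATE db.employees SET id = -1 WHERE id NOT IN (SELECT id FROM db.updated_id)")
    //   => AnalysisException: Null-aware predicate subqueries are not currently supported in UPDATE

    // A NOT EXISTS formulation passes the check and is rewritten like any other subquery.
    spark.sql(
      """UPDATE db.employees t
        |SET id = -1
        |WHERE NOT EXISTS (SELECT 1 FROM db.updated_id u WHERE t.id = u.id)""".stripMargin)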
diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesInRowLevelOperations.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesInRowLevelOperations.scala
index f544c02..f0794d7 100644
--- a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesInRowLevelOperations.scala
+++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PullupCorrelatedPredicatesInRowLevelOperations.scala
@@ -19,10 +19,12 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
+import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
 import org.apache.spark.sql.catalyst.plans.logical.DeleteFromTable
 import org.apache.spark.sql.catalyst.plans.logical.Filter
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.UpdateTable
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.utils.PlanUtils.isIcebergRelation
 
@@ -30,12 +32,20 @@ import org.apache.spark.sql.catalyst.utils.PlanUtils.isIcebergRelation
 object PullupCorrelatedPredicatesInRowLevelOperations extends Rule[LogicalPlan] {
   override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case d @ DeleteFromTable(table, Some(cond)) if SubqueryExpression.hasSubquery(cond) && isIcebergRelation(table) =>
-      // Spark pulls up correlated predicates only for UnaryNodes
-      // DeleteFromTable does not extend UnaryNode so it is ignored in that rule
-      // We have this workaround until it is fixed in Spark
-      val filter = Filter(cond, table)
-      val transformedFilter = PullupCorrelatedPredicates.apply(filter)
-      val transformedCond = transformedFilter.asInstanceOf[Filter].condition
+      val transformedCond = transformCond(table, cond)
       d.copy(condition = Some(transformedCond))
+
+    case u @ UpdateTable(table, _, Some(cond)) if SubqueryExpression.hasSubquery(cond) && isIcebergRelation(table) =>
+      val transformedCond = transformCond(table, cond)
+      u.copy(condition = Some(transformedCond))
+  }
+
+  // Spark pulls up correlated predicates only for UnaryNodes
+  // DeleteFromTable and UpdateTable do not extend UnaryNode so they are ignored in that rule
+  // We have this workaround until it is fixed in Spark
+  private def transformCond(table: LogicalPlan, cond: Expression): Expression = {
+    val filter = Filter(cond, table)
+    val transformedFilter = PullupCorrelatedPredicates.apply(filter)
+    transformedFilter.asInstanceOf[Filter].condition
   }
 }
diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteUpdate.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteUpdate.scala
index ab6d524..ed966b0 100644
--- a/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteUpdate.scala
+++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteUpdate.scala
@@ -19,23 +19,28 @@
 
 package org.apache.spark.sql.catalyst.optimizer
 
-import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions.Alias
+import org.apache.spark.sql.catalyst.expressions.EqualNullSafe
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.expressions.If
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.expressions.Not
 import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
 import org.apache.spark.sql.catalyst.plans.logical.Assignment
+import org.apache.spark.sql.catalyst.plans.logical.DynamicFileFilter
 import org.apache.spark.sql.catalyst.plans.logical.Filter
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.plans.logical.Project
 import org.apache.spark.sql.catalyst.plans.logical.ReplaceData
+import org.apache.spark.sql.catalyst.plans.logical.Union
 import org.apache.spark.sql.catalyst.plans.logical.UpdateTable
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.utils.PlanUtils.isIcebergRelation
 import org.apache.spark.sql.catalyst.utils.RewriteRowLevelOperationHelper
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.execution.datasources.v2.ExtendedDataSourceV2Implicits
+import org.apache.spark.sql.types.BooleanType
 
 case class RewriteUpdate(spark: SparkSession) extends Rule[LogicalPlan] with RewriteRowLevelOperationHelper {
 
@@ -45,7 +50,35 @@ case class RewriteUpdate(spark: SparkSession) extends Rule[LogicalPlan] with Rew
   override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
     case UpdateTable(r: DataSourceV2Relation, assignments, Some(cond))
         if isIcebergRelation(r) && SubqueryExpression.hasSubquery(cond) =>
-      throw new AnalysisException("UPDATE statements with subqueries are not currently supported")
+
+      val writeInfo = newWriteInfo(r.schema)
+      val mergeBuilder = r.table.asMergeable.newMergeBuilder("update", writeInfo)
+
+      // since we process matched and unmatched rows in two separate jobs,
+      // there will be two scans, but we want to execute the dynamic file filter only once,
+      // so the first job uses DynamicFileFilter and the second one uses the underlying scan plan;
+      // both jobs share the same SparkMergeScan instance to ensure they operate on the same files
+      val matchingRowsPlanBuilder = scanRelation => Filter(cond, scanRelation)
+      val scanPlan = buildDynamicFilterScanPlan(spark, r.table, r.output, mergeBuilder, cond, matchingRowsPlanBuilder)
+      val underlyingScanPlan = scanPlan match {
+        case DynamicFileFilter(plan, _, _) => plan.clone()
+        case _ => scanPlan.clone()
+      }
+
+      // build a plan for records that match the cond and should be updated
+      val matchedRowsPlan = Filter(cond, scanPlan)
+      val updatedRowsPlan = buildUpdateProjection(r, matchedRowsPlan, assignments)
+
+      // build a plan for records that did not match the cond but had to be copied over
+      val remainingRowFilter = Not(EqualNullSafe(cond, Literal(true, BooleanType)))
+      val remainingRowsPlan = Filter(remainingRowFilter, Project(r.output, underlyingScanPlan))
+
+      // new state is a union of updated and copied over records
+      val updatePlan = Union(updatedRowsPlan, remainingRowsPlan)
+
+      val mergeWrite = mergeBuilder.asWriteBuilder.buildForBatch()
+      val writePlan = buildWritePlan(updatePlan, r.table)
+      ReplaceData(r, mergeWrite, writePlan)
 
     case UpdateTable(r: DataSourceV2Relation, assignments, Some(cond)) if isIcebergRelation(r) =>
       val writeInfo = newWriteInfo(r.schema)
@@ -65,7 +98,7 @@ case class RewriteUpdate(spark: SparkSession) extends Rule[LogicalPlan] with Rew
       relation: DataSourceV2Relation,
       scanPlan: LogicalPlan,
       assignments: Seq[Assignment],
-      cond: Expression): LogicalPlan = {
+      cond: Expression = Literal.TrueLiteral): LogicalPlan = {
 
     // this method relies on the fact that the assignments have been aligned before
     require(relation.output.size == assignments.size, "assignments must be aligned")
@@ -74,12 +107,14 @@ case class RewriteUpdate(spark: SparkSession) extends Rule[LogicalPlan] with Rew
     val assignedExprs = assignments.map(_.value)
     val updatedExprs = assignedExprs.zip(relation.output).map { case (assignedExpr, attr) =>
       // use semanticEquals to avoid unnecessary if expressions as we may run after operator optimization
-      val updatedExpr = if (attr.semanticEquals(assignedExpr)) {
+      if (attr.semanticEquals(assignedExpr)) {
         attr
+      } else if (cond == Literal.TrueLiteral) {
+        Alias(assignedExpr, attr.name)()
       } else {
-        If(cond, assignedExpr, attr)
+        val updatedExpr = If(cond, assignedExpr, attr)
+        Alias(updatedExpr, attr.name)()
       }
-      Alias(updatedExpr, attr.name)()
     }
 
     Project(updatedExprs, scanPlan)
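
A note on the remaining-row filter introduced above: Not(EqualNullSafe(cond, true)) is used instead
of Not(cond) because, under SQL's three-valued logic, rows for which the condition evaluates to NULL
must be copied over unchanged, and NOT (cond <=> TRUE) is true for both FALSE and NULL. A quick way
to see the difference in a Spark shell (illustrative only; assumes an active SparkSession named
spark):

    // Null-safe comparison: NULL <=> TRUE is FALSE, so the negation keeps the row.
    spark.sql("SELECT NOT (CAST(NULL AS BOOLEAN) <=> TRUE)").show()  // prints true
    // Plain negation propagates NULL, and a Filter would silently drop the row.
    spark.sql("SELECT NOT CAST(NULL AS BOOLEAN)").show()             // prints null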
diff --git a/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkRowLevelOperationsTestBase.java b/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkRowLevelOperationsTestBase.java
index e9bc96f..863844a 100644
--- a/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkRowLevelOperationsTestBase.java
+++ b/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkRowLevelOperationsTestBase.java
@@ -30,6 +30,7 @@ import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
 import org.apache.iceberg.spark.SparkCatalog;
 import org.apache.iceberg.spark.SparkSessionCatalog;
 import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
 import org.apache.spark.sql.Encoders;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
@@ -152,6 +153,10 @@ public abstract class SparkRowLevelOperationsTestBase extends SparkExtensionsTes
     ds.createOrReplaceTempView(name);
   }
 
+  protected <T> void createOrReplaceView(String name, List<T> data, Encoder<T> encoder) {
+    spark.createDataset(data, encoder).createOrReplaceTempView(name);
+  }
+
   private Dataset<Row> toDS(String schema, String jsonData) {
     List<String> jsonRows = Arrays.stream(jsonData.split("\n"))
         .filter(str -> str.trim().length() > 0)
diff --git a/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestDelete.java b/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestDelete.java
index c3c6eaf..846a275 100644
--- a/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestDelete.java
+++ b/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestDelete.java
@@ -40,7 +40,6 @@ import org.apache.iceberg.relocated.com.google.common.util.concurrent.MoreExecut
 import org.apache.spark.SparkException;
 import org.apache.spark.sql.AnalysisException;
 import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.Encoder;
 import org.apache.spark.sql.Encoders;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
@@ -475,7 +474,7 @@ public abstract class TestDelete extends SparkRowLevelOperationsTestBase {
     createOrReplaceView("deleted_id", Arrays.asList(-1, -2, null), Encoders.INT());
 
     AssertHelpers.assertThrows("Should complain about NOT IN subquery",
-        AnalysisException.class, "Null-aware predicate sub-queries are not currently supported",
+        AnalysisException.class, "Null-aware predicate subqueries are not currently supported",
         () -> sql("DELETE FROM %s WHERE id NOT IN (SELECT * FROM deleted_id)", tableName));
   }
 
@@ -737,10 +736,6 @@ public abstract class TestDelete extends SparkRowLevelOperationsTestBase {
     initTable();
   }
 
-  protected <T> void createOrReplaceView(String name, List<T> data, Encoder<T> encoder) {
-    spark.createDataset(data, encoder).createOrReplaceTempView(name);
-  }
-
   protected void append(Employee... employees) throws NoSuchTableException {
     List<Employee> input = Arrays.asList(employees);
     Dataset<Row> inputDF = spark.createDataFrame(input, Employee.class);
diff --git a/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestUpdate.java b/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestUpdate.java
index 441773d..3ff3e43 100644
--- a/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestUpdate.java
+++ b/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestUpdate.java
@@ -20,6 +20,7 @@
 package org.apache.iceberg.spark.extensions;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.ExecutionException;
@@ -71,6 +72,9 @@ public abstract class TestUpdate extends SparkRowLevelOperationsTestBase {
   @After
   public void removeTables() {
     sql("DROP TABLE IF EXISTS %s", tableName);
+    sql("DROP TABLE IF EXISTS updated_id");
+    sql("DROP TABLE IF EXISTS updated_dep");
+    sql("DROP TABLE IF EXISTS deleted_employee");
   }
 
   @Test
@@ -520,6 +524,281 @@ public abstract class TestUpdate extends SparkRowLevelOperationsTestBase {
   }
 
   @Test
+  public void testUpdateWithInSubquery() {
+    createAndInitTable("id INT, dep STRING");
+
+    append(tableName,
+        "{ \"id\": 1, \"dep\": \"hr\" }\n" +
+        "{ \"id\": 2, \"dep\": \"hardware\" }\n" +
+        "{ \"id\": null, \"dep\": \"hr\" }");
+
+    createOrReplaceView("updated_id", Arrays.asList(0, 1, null), Encoders.INT());
+    createOrReplaceView("updated_dep", Arrays.asList("software", "hr"), Encoders.STRING());
+
+    sql("UPDATE %s SET id = -1 WHERE " +
+        "id IN (SELECT * FROM updated_id) AND " +
+        "dep IN (SELECT * from updated_dep)", tableName);
+    assertEquals("Should have expected rows",
+        ImmutableList.of(row(-1, "hr"), row(2, "hardware"), row(null, "hr")),
+        sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+
+    sql("UPDATE %s SET id = 5 WHERE id IS NULL OR id IN (SELECT value + 1 FROM updated_id)", tableName);
+    assertEquals("Should have expected rows",
+        ImmutableList.of(row(-1, "hr"), row(5, "hardware"), row(5, "hr")),
+        sql("SELECT * FROM %s ORDER BY id, dep", tableName));
+
+    append(tableName,
+        "{ \"id\": null, \"dep\": \"hr\" }\n" +
+        "{ \"id\": 2, \"dep\": \"hr\" }");
+    assertEquals("Should have expected rows",
+        ImmutableList.of(row(-1, "hr"), row(2, "hr"), row(5, "hardware"), row(5, "hr"), row(null, "hr")),
+        sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST, dep", tableName));
+
+    sql("UPDATE %s SET id = 10 WHERE id IN (SELECT value + 2 FROM updated_id) AND dep = 'hr'", tableName);
+    assertEquals("Should have expected rows",
+        ImmutableList.of(row(-1, "hr"), row(5, "hardware"), row(5, "hr"), row(10, "hr"), row(null, "hr")),
+        sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST, dep", tableName));
+  }
+
+  @Test
+  public void testUpdateWithInSubqueryAndDynamicFileFiltering() {
+    createAndInitTable("id INT, dep STRING");
+    sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName);
+
+    sql("ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION", tableName);
+
+    append(tableName,
+        "{ \"id\": 1, \"dep\": \"hr\" }\n" +
+        "{ \"id\": 3, \"dep\": \"hr\" }");
+    append(tableName,
+        "{ \"id\": 1, \"dep\": \"hardware\" }\n" +
+        "{ \"id\": 2, \"dep\": \"hardware\" }");
+
+    createOrReplaceView("updated_id", Arrays.asList(-1, 2), Encoders.INT());
+
+    sql("UPDATE %s SET id = -1 WHERE id IN (SELECT * FROM updated_id)", tableName);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    Assert.assertEquals("Should have 3 snapshots", 3, Iterables.size(table.snapshots()));
+
+    Snapshot currentSnapshot = table.currentSnapshot();
+    validateSnapshot(currentSnapshot, "overwrite", "1", "1", "1");
+
+    assertEquals("Should have expected rows",
+        ImmutableList.of(row(-1, "hardware"), row(1, "hardware"), row(1, "hr"), row(3, "hr")),
+        sql("SELECT * FROM %s ORDER BY id, dep", tableName));
+  }
+
+  @Test
+  public void testUpdateWithSelfSubquery() {
+    createAndInitTable("id INT, dep STRING");
+
+    append(tableName,
+        "{ \"id\": 1, \"dep\": \"hr\" }\n" +
+        "{ \"id\": 2, \"dep\": \"hr\" }");
+
+    sql("UPDATE %s SET dep = 'x' WHERE id IN (SELECT id + 1 FROM %s)", tableName, tableName);
+    assertEquals("Should have expected rows",
+        ImmutableList.of(row(1, "hr"), row(2, "x")),
+        sql("SELECT * FROM %s ORDER BY id", tableName));
+
+    sql("UPDATE %s SET dep = 'y' WHERE " +
+        "id = (SELECT count(*) FROM (SELECT DISTINCT id FROM %s) AS t)", tableName, tableName);
+    assertEquals("Should have expected rows",
+        ImmutableList.of(row(1, "hr"), row(2, "y")),
+        sql("SELECT * FROM %s ORDER BY id", tableName));
+
+    sql("UPDATE %s SET id = (SELECT id - 2 FROM %s WHERE id = 1)", tableName, tableName);
+    assertEquals("Should have expected rows",
+        ImmutableList.of(row(-1, "hr"), row(-1, "y")),
+        sql("SELECT * FROM %s ORDER BY id, dep", tableName));
+  }
+
+  @Test
+  public void testUpdateWithMultiColumnInSubquery() {
+    createAndInitTable("id INT, dep STRING");
+
+    append(tableName,
+        "{ \"id\": 1, \"dep\": \"hr\" }\n" +
+        "{ \"id\": 2, \"dep\": \"hardware\" }\n" +
+        "{ \"id\": null, \"dep\": \"hr\" }");
+
+    List<Employee> deletedEmployees = Arrays.asList(new Employee(null, "hr"), new Employee(1, "hr"));
+    createOrReplaceView("deleted_employee", deletedEmployees, Encoders.bean(Employee.class));
+
+    sql("UPDATE %s SET dep = 'x', id = -1 WHERE (id, dep) IN (SELECT id, dep FROM deleted_employee)", tableName);
+    assertEquals("Should have expected rows",
+        ImmutableList.of(row(-1, "x"), row(2, "hardware"), row(null, "hr")),
+        sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+  }
+
+  @Ignore // TODO: not supported since SPARK-25154 fix is not yet available
+  public void testUpdateWithNotInSubquery() {
+    createAndInitTable("id INT, dep STRING");
+
+    append(tableName,
+        "{ \"id\": 1, \"dep\": \"hr\" }\n" +
+        "{ \"id\": 2, \"dep\": \"hardware\" }\n" +
+        "{ \"id\": null, \"dep\": \"hr\" }");
+
+    createOrReplaceView("updated_id", Arrays.asList(-1, -2, null), Encoders.INT());
+    createOrReplaceView("updated_dep", Arrays.asList("software", "hr"), Encoders.STRING());
+
+    // the file filter subquery (nested loop left-anti join) returns 0 records
+    sql("UPDATE %s SET id = -1 WHERE id NOT IN (SELECT * FROM updated_id)", tableName);
+    assertEquals("Should have expected rows",
+        ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")),
+        sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+
+    sql("UPDATE %s SET id = -1 WHERE id NOT IN (SELECT * FROM updated_id WHERE value IS NOT NULL)", tableName);
+    assertEquals("Should have expected rows",
+        ImmutableList.of(row(-1, "hr"), row(-1, "hardware"), row(null, "hr")),
+        sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST, dep", tableName));
+
+    sql("UPDATE %s SET id = 5 WHERE id NOT IN (SELECT * FROM updated_id) OR dep IN ('software', 'hr')", tableName);
+    assertEquals("Should have expected rows",
+        ImmutableList.of(row(-1, "hardware"), row(5, "hr"), row(5, "hr")),
+        sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST, dep", tableName));
+  }
+
+  @Test
+  public void testUpdateWithNotInSubqueryNotSupported() {
+    createAndInitTable("id INT, dep STRING");
+
+    createOrReplaceView("updated_id", Arrays.asList(-1, -2, null), Encoders.INT());
+
+    AssertHelpers.assertThrows("Should complain about NOT IN subquery",
+        AnalysisException.class, "Null-aware predicate subqueries are not currently supported",
+        () -> sql("UPDATE %s SET id = -1 WHERE id NOT IN (SELECT * FROM updated_id)", tableName));
+  }
+
+  @Test
+  public void testUpdateWithExistSubquery() {
+    createAndInitTable("id INT, dep STRING");
+
+    append(tableName,
+        "{ \"id\": 1, \"dep\": \"hr\" }\n" +
+        "{ \"id\": 2, \"dep\": \"hardware\" }\n" +
+        "{ \"id\": null, \"dep\": \"hr\" }");
+
+    createOrReplaceView("updated_id", Arrays.asList(-1, -2, null), Encoders.INT());
+    createOrReplaceView("updated_dep", Arrays.asList("hr", null), Encoders.STRING());
+
+    sql("UPDATE %s t SET id = -1 WHERE EXISTS (SELECT 1 FROM updated_id u WHERE t.id = u.value)", tableName);
+    assertEquals("Should have expected rows",
+        ImmutableList.of(row(1, "hr"), row(2, "hardware"), row(null, "hr")),
+        sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+
+    sql("UPDATE %s t SET dep = 'x', id = -1 WHERE " +
+        "EXISTS (SELECT 1 FROM updated_id u WHERE t.id = u.value + 2)", tableName);
+    assertEquals("Should have expected rows",
+        ImmutableList.of(row(-1, "x"), row(2, "hardware"), row(null, "hr")),
+        sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+
+    sql("UPDATE %s t SET id = -2 WHERE " +
+        "EXISTS (SELECT 1 FROM updated_id u WHERE t.id = u.value) OR " +
+        "t.id IS NULL", tableName);
+    assertEquals("Should have expected rows",
+        ImmutableList.of(row(-2, "hr"), row(-2, "x"), row(2, "hardware")),
+        sql("SELECT * FROM %s ORDER BY id, dep", tableName));
+
+    sql("UPDATE %s t SET id = 1 WHERE " +
+        "EXISTS (SELECT 1 FROM updated_id ui WHERE t.id = ui.value) AND " +
+        "EXISTS (SELECT 1 FROM updated_dep ud WHERE t.dep = ud.value)", tableName);
+    assertEquals("Should have expected rows",
+        ImmutableList.of(row(-2, "x"), row(1, "hr"), row(2, "hardware")),
+        sql("SELECT * FROM %s ORDER BY id, dep", tableName));
+  }
+
+  @Test
+  public void testUpdateWithNotExistsSubquery() {
+    createAndInitTable("id INT, dep STRING");
+
+    append(tableName,
+        "{ \"id\": 1, \"dep\": \"hr\" }\n" +
+        "{ \"id\": 2, \"dep\": \"hardware\" }\n" +
+        "{ \"id\": null, \"dep\": \"hr\" }");
+
+    createOrReplaceView("updated_id", Arrays.asList(-1, -2, null), Encoders.INT());
+    createOrReplaceView("updated_dep", Arrays.asList("hr", "software"), Encoders.STRING());
+
+    sql("UPDATE %s t SET id = -1 WHERE NOT EXISTS (SELECT 1 FROM updated_id u WHERE t.id = u.value + 2)", tableName);
+    assertEquals("Should have expected rows",
+        ImmutableList.of(row(-1, "hardware"), row(-1, "hr"), row(1, "hr")),
+        sql("SELECT * FROM %s ORDER BY id, dep", tableName));
+
+    sql("UPDATE %s t SET id = 5 WHERE " +
+        "NOT EXISTS (SELECT 1 FROM updated_id u WHERE t.id = u.value) OR " +
+        "t.id = 1", tableName);
+    assertEquals("Should have expected rows",
+        ImmutableList.of(row(-1, "hardware"), row(-1, "hr"), row(5, "hr")),
+        sql("SELECT * FROM %s ORDER BY id, dep", tableName));
+
+    sql("UPDATE %s t SET id = 10 WHERE " +
+        "NOT EXISTS (SELECT 1 FROM updated_id ui WHERE t.id = ui.value) AND " +
+        "EXISTS (SELECT 1 FROM updated_dep ud WHERE t.dep = ud.value)", tableName);
+    assertEquals("Should have expected rows",
+        ImmutableList.of(row(-1, "hardware"), row(-1, "hr"), row(10, "hr")),
+        sql("SELECT * FROM %s ORDER BY id, dep", tableName));
+  }
+
+  @Test
+  public void testUpdateWithScalarSubquery() {
+    createAndInitTable("id INT, dep STRING");
+
+    append(tableName,
+        "{ \"id\": 1, \"dep\": \"hr\" }\n" +
+        "{ \"id\": 2, \"dep\": \"hardware\" }\n" +
+        "{ \"id\": null, \"dep\": \"hr\" }");
+
+    createOrReplaceView("updated_id", Arrays.asList(1, 100, null), Encoders.INT());
+
+    sql("UPDATE %s SET id = -1 WHERE id <= (SELECT min(value) FROM updated_id)", tableName);
+    assertEquals("Should have expected rows",
+        ImmutableList.of(row(-1, "hr"), row(2, "hardware"), row(null, "hr")),
+        sql("SELECT * FROM %s ORDER BY id ASC NULLS LAST", tableName));
+  }
+
+  @Test
+  public void testUpdateThatRequiresGroupingBeforeWrite() {
+    createAndInitTable("id INT, dep STRING");
+    sql("ALTER TABLE %s ADD PARTITION FIELD dep", tableName);
+
+    append(tableName,
+        "{ \"id\": 0, \"dep\": \"hr\" }\n" +
+        "{ \"id\": 1, \"dep\": \"hr\" }\n" +
+        "{ \"id\": 2, \"dep\": \"hr\" }");
+
+    append(tableName,
+        "{ \"id\": 0, \"dep\": \"ops\" }\n" +
+        "{ \"id\": 1, \"dep\": \"ops\" }\n" +
+        "{ \"id\": 2, \"dep\": \"ops\" }");
+
+    append(tableName,
+        "{ \"id\": 0, \"dep\": \"hr\" }\n" +
+        "{ \"id\": 1, \"dep\": \"hr\" }\n" +
+        "{ \"id\": 2, \"dep\": \"hr\" }");
+
+    append(tableName,
+        "{ \"id\": 0, \"dep\": \"ops\" }\n" +
+        "{ \"id\": 1, \"dep\": \"ops\" }\n" +
+        "{ \"id\": 2, \"dep\": \"ops\" }");
+
+    createOrReplaceView("updated_id", Arrays.asList(1, 100), Encoders.INT());
+
+    String originalNumOfShufflePartitions = spark.conf().get("spark.sql.shuffle.partitions");
+    try {
+      // set the num of shuffle partitions to 1 to ensure we have only 1 writing task
+      spark.conf().set("spark.sql.shuffle.partitions", "1");
+
+      sql("UPDATE %s t SET id = -1 WHERE id IN (SELECT * FROM updated_id)", tableName);
+      Assert.assertEquals("Should have expected num of rows", 12L, spark.table(tableName).count());
+    } finally {
+      spark.conf().set("spark.sql.shuffle.partitions", originalNumOfShufflePartitions);
+    }
+  }
+
+  @Test
   public void testUpdateWithInvalidUpdates() {
     createAndInitTable("id INT, a ARRAY<STRUCT<c1:INT,c2:INT>>, m MAP<STRING,STRING>");
 
diff --git a/spark3/src/main/java/org/apache/iceberg/spark/source/SparkMergeScan.java b/spark3/src/main/java/org/apache/iceberg/spark/source/SparkMergeScan.java
index 9892ec1..ac2ee40 100644
--- a/spark3/src/main/java/org/apache/iceberg/spark/source/SparkMergeScan.java
+++ b/spark3/src/main/java/org/apache/iceberg/spark/source/SparkMergeScan.java
@@ -121,7 +121,7 @@ class SparkMergeScan extends SparkBatchScan implements SupportsFileFilter {
   }
 
   // should be accessible to the write
-  List<FileScanTask> files() {
+  synchronized List<FileScanTask> files() {
     if (files == null) {
       TableScan scan = table
           .newScan()
@@ -148,7 +148,7 @@ class SparkMergeScan extends SparkBatchScan implements SupportsFileFilter {
   }
 
   @Override
-  protected List<CombinedScanTask> tasks() {
+  protected synchronized List<CombinedScanTask> tasks() {
     if (tasks == null) {
       CloseableIterable<FileScanTask> splitFiles = TableScanUtil.splitFiles(
           CloseableIterable.withNoopClose(files()),
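
Finally, the SparkMergeScan change makes the lazy file and task planning thread-safe: after this
commit the matched-rows job and the copied-rows job share a single scan instance, so both callers
must observe the same planned file list. A simplified sketch of the pattern the synchronized
keyword protects (SharedScan is illustrative, not the real class):

    // The first caller plans and caches the file list; later callers,
    // possibly on another thread, reuse the cached result.
    class SharedScan(planFiles: () => Seq[String]) {
      private var cached: Seq[String] = _

      def files(): Seq[String] = synchronized {
        if (cached == null) {
          cached = planFiles()
        }
        cached
      }
    }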