Posted to commits@iceberg.apache.org by ao...@apache.org on 2021/03/24 18:41:17 UTC

[iceberg] 02/18: Spark: Refresh relation cache in DELETE and MERGE (#2154)

This is an automated email from the ASF dual-hosted git repository.

aokolnychyi pushed a commit to branch 0.11.x
in repository https://gitbox.apache.org/repos/asf/iceberg.git

commit 9b49c2175aabc730243e07a4041fb158554fe756
Author: Anton Okolnychyi <ao...@apple.com>
AuthorDate: Tue Jan 26 12:42:56 2021 -0800

    Spark: Refresh relation cache in DELETE and MERGE (#2154)
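
    Context for the change (condensed from the diff below): without a cache
    refresh, a plan cached over the target table (for example a temp view
    cached with CACHE TABLE) kept serving rows that a DELETE or MERGE had
    already rewritten, because ReplaceDataExec finished the batch write
    without notifying Spark's CacheManager. The strategy now builds a refresh
    callback from the target relation and the exec node invokes it once the
    write succeeds. Roughly, the two pieces fit together as follows (a
    condensed sketch, not a standalone file; both snippets live in
    org.apache.spark.sql.execution.datasources.v2 alongside the other exec
    nodes):

        // In ExtendedDataSourceV2Strategy: hand ReplaceDataExec a callback
        // that re-executes every cached plan referencing the written relation.
        case ReplaceData(relation, batchWrite, query) =>
          ReplaceDataExec(batchWrite, refreshCache(relation), planLater(query)) :: Nil

        private def refreshCache(r: NamedRelation)(): Unit = {
          spark.sharedState.cacheManager.recacheByPlan(spark, r)
        }

        // In ReplaceDataExec: run the write first, refresh the cache only
        // after it succeeds, then return the written rows unchanged.
        case class ReplaceDataExec(
            batchWrite: BatchWrite,
            refreshCache: () => Unit,
            query: SparkPlan) extends V2TableWriteExec {

          override protected def run(): Seq[InternalRow] = {
            prepare()
            val writtenRows = writeWithV2(batchWrite)
            refreshCache()
            writtenRows
          }
        }

    The new tests in TestDelete and TestMerge assert the user-visible effect:
    a cached temp view over the table reflects the DELETE/MERGE result instead
    of returning stale rows.
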
---
 .../v2/ExtendedDataSourceV2Strategy.scala          |  9 ++++--
 .../execution/datasources/v2/ReplaceDataExec.scala |  9 ++++--
 .../SparkRowLevelOperationsTestBase.java           | 24 +++++++++++++--
 .../iceberg/spark/extensions/TestDelete.java       | 35 ++++++++++++++++++++++
 .../apache/iceberg/spark/extensions/TestMerge.java | 29 ++++++++++++++++++
 5 files changed, 100 insertions(+), 6 deletions(-)

diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala
index 3a8072e..ac2fc2e 100644
--- a/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala
+++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExtendedDataSourceV2Strategy.scala
@@ -25,6 +25,7 @@ import org.apache.iceberg.spark.SparkSessionCatalog
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.Strategy
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.NamedRelation
 import org.apache.spark.sql.catalyst.expressions.And
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
@@ -81,8 +82,8 @@ case class ExtendedDataSourceV2Strategy(spark: SparkSession) extends Strategy {
       val batchExec = ExtendedBatchScanExec(output, scan)
       withProjectAndFilter(project, filters, batchExec, !batchExec.supportsColumnar) :: Nil
 
-    case ReplaceData(_, batchWrite, query) =>
-      ReplaceDataExec(batchWrite, planLater(query)) :: Nil
+    case ReplaceData(relation, batchWrite, query) =>
+      ReplaceDataExec(batchWrite, refreshCache(relation), planLater(query)) :: Nil
 
     case MergeInto(mergeIntoParams, output, child) =>
       MergeIntoExec(mergeIntoParams, output, planLater(child)) :: Nil
@@ -113,6 +114,10 @@ case class ExtendedDataSourceV2Strategy(spark: SparkSession) extends Strategy {
     }
   }
 
+  private def refreshCache(r: NamedRelation)(): Unit = {
+    spark.sharedState.cacheManager.recacheByPlan(spark, r)
+  }
+
   private object IcebergCatalogAndIdentifier {
     def unapply(identifier: Seq[String]): Option[(TableCatalog, Identifier)] = {
       val catalogAndIdentifier = Spark3Util.catalogAndIdentifier(spark, identifier.asJava)
diff --git a/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceDataExec.scala b/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceDataExec.scala
index 2551a18..f26a8c7 100644
--- a/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceDataExec.scala
+++ b/spark3-extensions/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ReplaceDataExec.scala
@@ -23,11 +23,16 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connector.write.BatchWrite
 import org.apache.spark.sql.execution.SparkPlan
 
-case class ReplaceDataExec(batchWrite: BatchWrite, query: SparkPlan) extends V2TableWriteExec {
+case class ReplaceDataExec(
+    batchWrite: BatchWrite,
+    refreshCache: () => Unit,
+    query: SparkPlan) extends V2TableWriteExec {
 
   override protected def run(): Seq[InternalRow] = {
     // calling prepare() ensures we execute DynamicFileFilter if present
     prepare()
-    writeWithV2(batchWrite)
+    val writtenRows = writeWithV2(batchWrite)
+    refreshCache()
+    writtenRows
   }
 }
diff --git a/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkRowLevelOperationsTestBase.java b/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkRowLevelOperationsTestBase.java
index b54d24c..1d19a27 100644
--- a/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkRowLevelOperationsTestBase.java
+++ b/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/SparkRowLevelOperationsTestBase.java
@@ -30,6 +30,8 @@ import org.apache.iceberg.spark.SparkCatalog;
 import org.apache.iceberg.spark.SparkSessionCatalog;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Encoders;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
 import org.junit.Assert;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
@@ -110,8 +112,21 @@ public abstract class SparkRowLevelOperationsTestBase extends SparkExtensionsTes
   }
 
   protected void createAndInitTable(String schema) {
+    createAndInitTable(schema, null);
+  }
+
+  protected void createAndInitTable(String schema, String jsonData) {
     sql("CREATE TABLE %s (%s) USING iceberg", tableName, schema);
     initTable();
+
+    if (jsonData != null) {
+      try {
+        Dataset<Row> ds = toDS(schema, jsonData);
+        ds.writeTo(tableName).append();
+      } catch (NoSuchTableException e) {
+        throw new RuntimeException("Failed to write data", e);
+      }
+    }
   }
 
   protected void createOrReplaceView(String name, String jsonData) {
@@ -119,15 +134,20 @@ public abstract class SparkRowLevelOperationsTestBase extends SparkExtensionsTes
   }
 
   protected void createOrReplaceView(String name, String schema, String jsonData) {
+    Dataset<Row> ds = toDS(schema, jsonData);
+    ds.createOrReplaceTempView(name);
+  }
+
+  private Dataset<Row> toDS(String schema, String jsonData) {
     List<String> jsonRows = Arrays.stream(jsonData.split("\n"))
         .filter(str -> str.trim().length() > 0)
         .collect(Collectors.toList());
     Dataset<String> jsonDS = spark.createDataset(jsonRows, Encoders.STRING());
 
     if (schema != null) {
-      spark.read().schema(schema).json(jsonDS).createOrReplaceTempView(name);
+      return spark.read().schema(schema).json(jsonDS);
     } else {
-      spark.read().json(jsonDS).createOrReplaceTempView(name);
+      return spark.read().json(jsonDS);
     }
   }
 }
diff --git a/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestDelete.java b/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestDelete.java
index 31986db..3f3b010 100644
--- a/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestDelete.java
+++ b/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestDelete.java
@@ -657,6 +657,41 @@ public abstract class TestDelete extends SparkRowLevelOperationsTestBase {
     Assert.assertTrue("Timeout", executorService.awaitTermination(2, TimeUnit.MINUTES));
   }
 
+  @Test
+  public void testDeleteRefreshesRelationCache() throws NoSuchTableException {
+    createAndInitPartitionedTable();
+
+    append(new Employee(1, "hr"), new Employee(3, "hr"));
+    append(new Employee(1, "hardware"), new Employee(2, "hardware"));
+
+    Dataset<Row> query = spark.sql("SELECT * FROM " + tableName + " WHERE id = 1");
+    query.createOrReplaceTempView("tmp");
+
+    spark.sql("CACHE TABLE tmp");
+
+    assertEquals("View should have correct data",
+        ImmutableList.of(row(1, "hardware"), row(1, "hr")),
+        sql("SELECT * FROM tmp ORDER BY id, dep"));
+
+    sql("DELETE FROM %s WHERE id = 1", tableName);
+
+    Table table = validationCatalog.loadTable(tableIdent);
+    Assert.assertEquals("Should have 3 snapshots", 3, Iterables.size(table.snapshots()));
+
+    Snapshot currentSnapshot = table.currentSnapshot();
+    validateSnapshot(currentSnapshot, "overwrite", "2", "2", "2");
+
+    assertEquals("Should have expected rows",
+        ImmutableList.of(row(2, "hardware"), row(3, "hr")),
+        sql("SELECT * FROM %s ORDER BY id, dep", tableName));
+
+    assertEquals("Should refresh the relation cache",
+        ImmutableList.of(),
+        sql("SELECT * FROM tmp ORDER BY id, dep"));
+
+    spark.sql("UNCACHE TABLE tmp");
+  }
+
   // TODO: multiple stripes for ORC
 
   protected void validateSnapshot(Snapshot snapshot, String operation, String changedPartitionCount,
diff --git a/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java b/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java
index c5dc5aa..4a004f5 100644
--- a/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java
+++ b/spark3-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java
@@ -21,8 +21,11 @@ package org.apache.iceberg.spark.extensions;
 
 import java.util.Map;
 import org.apache.iceberg.AssertHelpers;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
 import org.apache.spark.sql.AnalysisException;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
 import org.junit.After;
 import org.junit.BeforeClass;
 import org.junit.Test;
@@ -50,6 +53,32 @@ public abstract class TestMerge extends SparkRowLevelOperationsTestBase {
   // TODO: tests for subqueries in conditions
 
   @Test
+  public void testMergeRefreshesRelationCache() {
+    createAndInitTable("id INT, name STRING", "{ \"id\": 1, \"name\": \"n1\" }");
+    createOrReplaceView("source", "{ \"id\": 1, \"name\": \"n2\" }");
+
+    Dataset<Row> query = spark.sql("SELECT name FROM " + tableName);
+    query.createOrReplaceTempView("tmp");
+
+    spark.sql("CACHE TABLE tmp");
+
+    assertEquals("View should have correct data",
+        ImmutableList.of(row("n1")),
+        sql("SELECT * FROM tmp"));
+
+    sql("MERGE INTO %s t USING source s " +
+        "ON t.id == s.id " +
+        "WHEN MATCHED THEN " +
+        "  UPDATE SET t.name = s.name", tableName);
+
+    assertEquals("View should have correct data",
+        ImmutableList.of(row("n2")),
+        sql("SELECT * FROM tmp"));
+
+    spark.sql("UNCACHE TABLE tmp");
+  }
+
+  @Test
   public void testMergeWithNonExistingColumns() {
     createAndInitTable("id INT, c STRUCT<n1:INT,n2:STRUCT<dn1:INT,dn2:INT>>");
     createOrReplaceView("source", "{ \"c1\": -100, \"c2\": -200 }");