You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@iceberg.apache.org by ao...@apache.org on 2022/10/05 00:46:30 UTC

[iceberg] branch master updated: Spark 3.3: Fix nullability in merge-on-read projections (#5880)

This is an automated email from the ASF dual-hosted git repository.

aokolnychyi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git


The following commit(s) were added to refs/heads/master by this push:
     new 3170e6ebc6 Spark 3.3: Fix nullability in merge-on-read projections (#5880)
3170e6ebc6 is described below

commit 3170e6ebc65b0c126e57956fd2df5efc59ff88ed
Author: Anton Okolnychyi <ao...@apple.com>
AuthorDate: Tue Oct 4 17:46:22 2022 -0700

    Spark 3.3: Fix nullability in merge-on-read projections (#5880)
---
 .../catalyst/analysis/RewriteMergeIntoTable.scala  | 57 +++++++++++++++++++++-
 .../analysis/RewriteRowLevelIcebergCommand.scala   | 25 +++-------
 .../spark/extensions/TestCopyOnWriteMerge.java     | 35 -------------
 .../apache/iceberg/spark/extensions/TestMerge.java | 32 ++++++++++++
 4 files changed, 95 insertions(+), 54 deletions(-)

diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
index ff433831aa..2e720bdd44 100644
--- a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
+++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
@@ -20,6 +20,7 @@
 package org.apache.spark.sql.catalyst.analysis
 
 import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.ProjectingInternalRow
 import org.apache.spark.sql.catalyst.expressions.Alias
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
@@ -54,6 +55,7 @@ import org.apache.spark.sql.catalyst.plans.logical.ReplaceIcebergData
 import org.apache.spark.sql.catalyst.plans.logical.UpdateAction
 import org.apache.spark.sql.catalyst.plans.logical.WriteDelta
 import org.apache.spark.sql.catalyst.util.RowDeltaUtils._
+import org.apache.spark.sql.catalyst.util.WriteDeltaProjections
 import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations
 import org.apache.spark.sql.connector.expressions.FieldReference
 import org.apache.spark.sql.connector.expressions.NamedReference
@@ -62,6 +64,8 @@ import org.apache.spark.sql.connector.write.RowLevelOperation.Command.MERGE
 import org.apache.spark.sql.connector.write.RowLevelOperationTable
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.types.StructField
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
 
 /**
@@ -297,7 +301,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelIcebergCommand {
 
     // build a plan to write the row delta to the table
     val writeRelation = relation.copy(table = operationTable)
-    val projections = buildWriteDeltaProjections(mergeRows, rowAttrs, rowIdAttrs, metadataAttrs)
+    val projections = buildMergeDeltaProjections(mergeRows, rowAttrs, rowIdAttrs, metadataAttrs)
     WriteDelta(writeRelation, mergeRows, relation, projections)
   }
 
@@ -384,4 +388,55 @@ object RewriteMergeIntoTable extends RewriteRowLevelIcebergCommand {
   private def resolveAttrRef(ref: NamedReference, plan: LogicalPlan): AttributeReference = {
     ExtendedV2ExpressionUtils.resolveRef[AttributeReference](ref, plan)
   }
+
+  private def buildMergeDeltaProjections(
+      mergeRows: MergeRows,
+      rowAttrs: Seq[Attribute],
+      rowIdAttrs: Seq[Attribute],
+      metadataAttrs: Seq[Attribute]): WriteDeltaProjections = {
+
+    val outputAttrs = mergeRows.output
+
+    val outputs = mergeRows.matchedOutputs ++ mergeRows.notMatchedOutputs
+    val insertAndUpdateOutputs = outputs.filterNot(_.head == Literal(DELETE_OPERATION))
+    val updateAndDeleteOutputs = outputs.filterNot(_.head == Literal(INSERT_OPERATION))
+
+    val rowProjection = if (rowAttrs.nonEmpty) {
+      Some(newLazyProjection(insertAndUpdateOutputs, outputAttrs, rowAttrs))
+    } else {
+      None
+    }
+
+    val rowIdProjection = newLazyProjection(updateAndDeleteOutputs, outputAttrs, rowIdAttrs)
+
+    val metadataProjection = if (metadataAttrs.nonEmpty) {
+      Some(newLazyProjection(updateAndDeleteOutputs, outputAttrs, metadataAttrs))
+    } else {
+      None
+    }
+
+    WriteDeltaProjections(rowProjection, rowIdProjection, metadataProjection)
+  }
+
+  // the projection is done by name, ignoring expr IDs
+  private def newLazyProjection(
+      outputs: Seq[Seq[Expression]],
+      outputAttrs: Seq[Attribute],
+      projectedAttrs: Seq[Attribute]): ProjectingInternalRow = {
+
+    val projectedOrdinals = projectedAttrs.map(attr => outputAttrs.indexWhere(_.name == attr.name))
+
+    val structFields = projectedAttrs.zip(projectedOrdinals).map { case (attr, ordinal) =>
+      // output attr is nullable if at least one action may produce null for that attr
+      // but row ID and metadata attrs are projected only in update/delete actions and
+      // row attrs are projected only in insert/update actions
+      // that's why the projection schema must rely only on relevant action outputs
+      // instead of blindly inheriting the output attr nullability
+      val nullable = outputs.exists(output => output(ordinal).nullable)
+      StructField(attr.name, attr.dataType, nullable, attr.metadata)
+    }
+    val schema = StructType(structFields)
+
+    ProjectingInternalRow(schema, projectedOrdinals)
+  }
 }
diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelIcebergCommand.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelIcebergCommand.scala
index ec3b9576d0..c7378f71e6 100644
--- a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelIcebergCommand.scala
+++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelIcebergCommand.scala
@@ -40,20 +40,15 @@ trait RewriteRowLevelIcebergCommand extends RewriteRowLevelCommand {
       metadataAttrs: Seq[Attribute]): WriteDeltaProjections = {
 
     val rowProjection = if (rowAttrs.nonEmpty) {
-      Some(newLazyProjection(plan, rowAttrs, usePlanTypes = true))
+      Some(newLazyProjection(plan, rowAttrs))
     } else {
       None
     }
 
-    // in MERGE, the plan may contain both delete and insert records that may affect
-    // the nullability of metadata columns (e.g. metadata columns for new records are always null)
-    // since metadata columns are never passed with new records to insert,
-    // use the actual metadata column types instead of the ones present in the plan
-
-    val rowIdProjection = newLazyProjection(plan, rowIdAttrs, usePlanTypes = false)
+    val rowIdProjection = newLazyProjection(plan, rowIdAttrs)
 
     val metadataProjection = if (metadataAttrs.nonEmpty) {
-      Some(newLazyProjection(plan, metadataAttrs, usePlanTypes = false))
+      Some(newLazyProjection(plan, metadataAttrs))
     } else {
       None
     }
@@ -64,17 +59,11 @@ trait RewriteRowLevelIcebergCommand extends RewriteRowLevelCommand {
   // the projection is done by name, ignoring expr IDs
   private def newLazyProjection(
       plan: LogicalPlan,
-      attrs: Seq[Attribute],
-      usePlanTypes: Boolean): ProjectingInternalRow = {
+      projectedAttrs: Seq[Attribute]): ProjectingInternalRow = {
 
-    val colOrdinals = attrs.map(attr => plan.output.indexWhere(_.name == attr.name))
-    val schema = if (usePlanTypes) {
-      val planAttrs = colOrdinals.map(plan.output(_))
-      StructType.fromAttributes(planAttrs)
-    } else {
-      StructType.fromAttributes(attrs)
-    }
-    ProjectingInternalRow(schema, colOrdinals)
+    val projectedOrdinals = projectedAttrs.map(attr => plan.output.indexWhere(_.name == attr.name))
+    val schema = StructType.fromAttributes(projectedOrdinals.map(plan.output(_)))
+    ProjectingInternalRow(schema, projectedOrdinals)
   }
 
   protected def resolveRowIdAttrs(
diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteMerge.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteMerge.java
index 8ee62ad2f6..27cbd1a9d5 100644
--- a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteMerge.java
+++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteMerge.java
@@ -20,9 +20,7 @@ package org.apache.iceberg.spark.extensions;
 
 import java.util.Map;
 import org.apache.iceberg.TableProperties;
-import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
-import org.junit.Test;
 
 public class TestCopyOnWriteMerge extends TestMerge {
 
@@ -40,37 +38,4 @@ public class TestCopyOnWriteMerge extends TestMerge {
   protected Map<String, String> extraTableProperties() {
     return ImmutableMap.of(TableProperties.MERGE_MODE, "copy-on-write");
   }
-
-  @Test
-  public void testMergeWithTableWithNonNullableColumn() {
-    createAndInitTable(
-        "id INT NOT NULL, dep STRING",
-        "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }");
-
-    createOrReplaceView(
-        "source",
-        "id INT NOT NULL, dep STRING",
-        "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n"
-            + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n"
-            + "{ \"id\": 6, \"dep\": \"emp-id-6\" }");
-
-    sql(
-        "MERGE INTO %s AS t USING source AS s "
-            + "ON t.id == s.id "
-            + "WHEN MATCHED AND t.id = 1 THEN "
-            + "  UPDATE SET * "
-            + "WHEN MATCHED AND t.id = 6 THEN "
-            + "  DELETE "
-            + "WHEN NOT MATCHED AND s.id = 2 THEN "
-            + "  INSERT *",
-        tableName);
-
-    ImmutableList<Object[]> expectedRows =
-        ImmutableList.of(
-            row(1, "emp-id-1"), // updated
-            row(2, "emp-id-2") // new
-            );
-    assertEquals(
-        "Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName));
-  }
 }
diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java
index c485dbfe2f..e7944eea74 100644
--- a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java
+++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java
@@ -1687,6 +1687,38 @@ public abstract class TestMerge extends SparkRowLevelOperationsTestBase {
         "Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName));
   }
 
+  @Test
+  public void testMergeWithTableWithNonNullableColumn() {
+    createAndInitTable(
+        "id INT NOT NULL, dep STRING",
+        "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }");
+
+    createOrReplaceView(
+        "source",
+        "id INT NOT NULL, dep STRING",
+        "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n"
+            + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n"
+            + "{ \"id\": 6, \"dep\": \"emp-id-6\" }");
+
+    sql(
+        "MERGE INTO %s AS t USING source AS s "
+            + "ON t.id == s.id "
+            + "WHEN MATCHED AND t.id = 1 THEN "
+            + "  UPDATE SET * "
+            + "WHEN MATCHED AND t.id = 6 THEN "
+            + "  DELETE "
+            + "WHEN NOT MATCHED AND s.id = 2 THEN "
+            + "  INSERT *",
+        tableName);
+
+    ImmutableList<Object[]> expectedRows =
+        ImmutableList.of(
+            row(1, "emp-id-1"), // updated
+            row(2, "emp-id-2")); // new
+    assertEquals(
+        "Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName));
+  }
+
   @Test
   public void testMergeWithNonExistingColumns() {
     createAndInitTable("id INT, c STRUCT<n1:INT,n2:STRUCT<dn1:INT,dn2:INT>>");