Posted to commits@iceberg.apache.org by ao...@apache.org on 2022/10/05 00:46:30 UTC
[iceberg] branch master updated: Spark 3.3: Fix nullability in merge-on-read projections (#5880)
This is an automated email from the ASF dual-hosted git repository.
aokolnychyi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git
The following commit(s) were added to refs/heads/master by this push:
new 3170e6ebc6 Spark 3.3: Fix nullability in merge-on-read projections (#5880)
3170e6ebc6 is described below
commit 3170e6ebc65b0c126e57956fd2df5efc59ff88ed
Author: Anton Okolnychyi <ao...@apple.com>
AuthorDate: Tue Oct 4 17:46:22 2022 -0700
Spark 3.3: Fix nullability in merge-on-read projections (#5880)
---
.../catalyst/analysis/RewriteMergeIntoTable.scala | 57 +++++++++++++++++++++-
.../analysis/RewriteRowLevelIcebergCommand.scala | 25 +++-------
.../spark/extensions/TestCopyOnWriteMerge.java | 35 -------------
.../apache/iceberg/spark/extensions/TestMerge.java | 32 ++++++++++++
4 files changed, 95 insertions(+), 54 deletions(-)
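
In short: the merge-on-read delta projections used to take nullability from the merged plan output, where an attribute is nullable whenever any MERGE action (insert, update, delete) can produce null for it. But row-ID and metadata columns are only ever projected for update/delete rows, and data columns only for insert/update rows, so the projection schema must be derived from the relevant action outputs alone. A minimal, self-contained sketch of that rule, using simplified stand-ins for the Catalyst types (names and shapes here are hypothetical, not the actual classes):

    // Only nullability matters for this sketch.
    final case class Expr(name: String, nullable: Boolean)
    // Each MERGE action emits one row of expressions; `op` marks the action.
    final case class ActionOutput(op: String, exprs: Seq[Expr])

    object NullabilitySketch extends App {
      // An attribute at `ordinal` is nullable only if at least one of the
      // actions that actually feed this projection may emit null there.
      def nullableAt(relevant: Seq[ActionOutput], ordinal: Int): Boolean =
        relevant.exists(_.exprs(ordinal).nullable)

      // `id` is NOT NULL in the table; the `_file` metadata column exists only
      // for rows read from a file (update/delete), never for inserted rows.
      val insert = ActionOutput("INSERT", Seq(Expr("id", false), Expr("_file", true)))
      val update = ActionOutput("UPDATE", Seq(Expr("id", false), Expr("_file", false)))
      val delete = ActionOutput("DELETE", Seq(Expr("id", true), Expr("_file", false)))
      val outputs = Seq(insert, update, delete)

      // Naive rule over all actions: both columns wrongly look nullable.
      assert(nullableAt(outputs, 0) && nullableAt(outputs, 1))
      // Fixed rule: row attrs consider insert/update outputs only,
      // row-ID/metadata attrs consider update/delete outputs only.
      assert(!nullableAt(outputs.filterNot(_.op == "DELETE"), 0)) // id stays NOT NULL
      assert(!nullableAt(outputs.filterNot(_.op == "INSERT"), 1)) // _file stays NOT NULL
    }

The diff below applies exactly this filtering in buildMergeDeltaProjections.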
diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
index ff433831aa..2e720bdd44 100644
--- a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
+++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteMergeIntoTable.scala
@@ -20,6 +20,7 @@
package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.ProjectingInternalRow
import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.AttributeReference
@@ -54,6 +55,7 @@ import org.apache.spark.sql.catalyst.plans.logical.ReplaceIcebergData
import org.apache.spark.sql.catalyst.plans.logical.UpdateAction
import org.apache.spark.sql.catalyst.plans.logical.WriteDelta
import org.apache.spark.sql.catalyst.util.RowDeltaUtils._
+import org.apache.spark.sql.catalyst.util.WriteDeltaProjections
import org.apache.spark.sql.connector.catalog.SupportsRowLevelOperations
import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.connector.expressions.NamedReference
@@ -62,6 +64,8 @@ import org.apache.spark.sql.connector.write.RowLevelOperation.Command.MERGE
import org.apache.spark.sql.connector.write.RowLevelOperationTable
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.sql.types.StructField
+import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
/**
@@ -297,7 +301,7 @@ object RewriteMergeIntoTable extends RewriteRowLevelIcebergCommand {
// build a plan to write the row delta to the table
val writeRelation = relation.copy(table = operationTable)
- val projections = buildWriteDeltaProjections(mergeRows, rowAttrs, rowIdAttrs, metadataAttrs)
+ val projections = buildMergeDeltaProjections(mergeRows, rowAttrs, rowIdAttrs, metadataAttrs)
WriteDelta(writeRelation, mergeRows, relation, projections)
}
@@ -384,4 +388,55 @@ object RewriteMergeIntoTable extends RewriteRowLevelIcebergCommand {
private def resolveAttrRef(ref: NamedReference, plan: LogicalPlan): AttributeReference = {
ExtendedV2ExpressionUtils.resolveRef[AttributeReference](ref, plan)
}
+
+ private def buildMergeDeltaProjections(
+ mergeRows: MergeRows,
+ rowAttrs: Seq[Attribute],
+ rowIdAttrs: Seq[Attribute],
+ metadataAttrs: Seq[Attribute]): WriteDeltaProjections = {
+
+ val outputAttrs = mergeRows.output
+
+ val outputs = mergeRows.matchedOutputs ++ mergeRows.notMatchedOutputs
+ val insertAndUpdateOutputs = outputs.filterNot(_.head == Literal(DELETE_OPERATION))
+ val updateAndDeleteOutputs = outputs.filterNot(_.head == Literal(INSERT_OPERATION))
+
+ val rowProjection = if (rowAttrs.nonEmpty) {
+ Some(newLazyProjection(insertAndUpdateOutputs, outputAttrs, rowAttrs))
+ } else {
+ None
+ }
+
+ val rowIdProjection = newLazyProjection(updateAndDeleteOutputs, outputAttrs, rowIdAttrs)
+
+ val metadataProjection = if (metadataAttrs.nonEmpty) {
+ Some(newLazyProjection(updateAndDeleteOutputs, outputAttrs, metadataAttrs))
+ } else {
+ None
+ }
+
+ WriteDeltaProjections(rowProjection, rowIdProjection, metadataProjection)
+ }
+
+ // the projection is done by name, ignoring expr IDs
+ private def newLazyProjection(
+ outputs: Seq[Seq[Expression]],
+ outputAttrs: Seq[Attribute],
+ projectedAttrs: Seq[Attribute]): ProjectingInternalRow = {
+
+ val projectedOrdinals = projectedAttrs.map(attr => outputAttrs.indexWhere(_.name == attr.name))
+
+ val structFields = projectedAttrs.zip(projectedOrdinals).map { case (attr, ordinal) =>
+ // output attr is nullable if at least one action may produce null for that attr
+ // but row ID and metadata attrs are projected only in update/delete actions and
+ // row attrs are projected only in insert/update actions
+ // that's why the projection schema must rely only on relevant action outputs
+ // instead of blindly inheriting the output attr nullability
+ val nullable = outputs.exists(output => output(ordinal).nullable)
+ StructField(attr.name, attr.dataType, nullable, attr.metadata)
+ }
+ val schema = StructType(structFields)
+
+ ProjectingInternalRow(schema, projectedOrdinals)
+ }
}
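
As the comment in buildMergeDeltaProjections notes, columns are matched by name, ignoring expression IDs. Stripped of the Catalyst machinery, the ordinal lookup that feeds ProjectingInternalRow amounts to the following sketch (column names are illustrative, not the exact internal ones):

    // Output of the MergeRows node: operation marker, then data, then metadata.
    val outputNames = Seq("__row_operation", "id", "dep", "_file", "_pos")

    // Projecting the metadata attrs resolves each name to its ordinal in that
    // output; ProjectingInternalRow then reads only those ordinals per row.
    val metadataAttrs = Seq("_file", "_pos")
    val projectedOrdinals = metadataAttrs.map(name => outputNames.indexWhere(_ == name))
    assert(projectedOrdinals == Seq(3, 4))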
diff --git a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelIcebergCommand.scala b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelIcebergCommand.scala
index ec3b9576d0..c7378f71e6 100644
--- a/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelIcebergCommand.scala
+++ b/spark/v3.3/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteRowLevelIcebergCommand.scala
@@ -40,20 +40,15 @@ trait RewriteRowLevelIcebergCommand extends RewriteRowLevelCommand {
metadataAttrs: Seq[Attribute]): WriteDeltaProjections = {
val rowProjection = if (rowAttrs.nonEmpty) {
- Some(newLazyProjection(plan, rowAttrs, usePlanTypes = true))
+ Some(newLazyProjection(plan, rowAttrs))
} else {
None
}
- // in MERGE, the plan may contain both delete and insert records that may affect
- // the nullability of metadata columns (e.g. metadata columns for new records are always null)
- // since metadata columns are never passed with new records to insert,
- // use the actual metadata column types instead of the ones present in the plan
-
- val rowIdProjection = newLazyProjection(plan, rowIdAttrs, usePlanTypes = false)
+ val rowIdProjection = newLazyProjection(plan, rowIdAttrs)
val metadataProjection = if (metadataAttrs.nonEmpty) {
- Some(newLazyProjection(plan, metadataAttrs, usePlanTypes = false))
+ Some(newLazyProjection(plan, metadataAttrs))
} else {
None
}
@@ -64,17 +59,11 @@ trait RewriteRowLevelIcebergCommand extends RewriteRowLevelCommand {
// the projection is done by name, ignoring expr IDs
private def newLazyProjection(
plan: LogicalPlan,
- attrs: Seq[Attribute],
- usePlanTypes: Boolean): ProjectingInternalRow = {
+ projectedAttrs: Seq[Attribute]): ProjectingInternalRow = {
- val colOrdinals = attrs.map(attr => plan.output.indexWhere(_.name == attr.name))
- val schema = if (usePlanTypes) {
- val planAttrs = colOrdinals.map(plan.output(_))
- StructType.fromAttributes(planAttrs)
- } else {
- StructType.fromAttributes(attrs)
- }
- ProjectingInternalRow(schema, colOrdinals)
+ val projectedOrdinals = projectedAttrs.map(attr => plan.output.indexWhere(_.name == attr.name))
+ val schema = StructType.fromAttributes(projectedOrdinals.map(plan.output(_)))
+ ProjectingInternalRow(schema, projectedOrdinals)
}
protected def resolveRowIdAttrs(
diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteMerge.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteMerge.java
index 8ee62ad2f6..27cbd1a9d5 100644
--- a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteMerge.java
+++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestCopyOnWriteMerge.java
@@ -20,9 +20,7 @@ package org.apache.iceberg.spark.extensions;
import java.util.Map;
import org.apache.iceberg.TableProperties;
-import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
-import org.junit.Test;
public class TestCopyOnWriteMerge extends TestMerge {
@@ -40,37 +38,4 @@ public class TestCopyOnWriteMerge extends TestMerge {
protected Map<String, String> extraTableProperties() {
return ImmutableMap.of(TableProperties.MERGE_MODE, "copy-on-write");
}
-
- @Test
- public void testMergeWithTableWithNonNullableColumn() {
- createAndInitTable(
- "id INT NOT NULL, dep STRING",
- "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }");
-
- createOrReplaceView(
- "source",
- "id INT NOT NULL, dep STRING",
- "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n"
- + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n"
- + "{ \"id\": 6, \"dep\": \"emp-id-6\" }");
-
- sql(
- "MERGE INTO %s AS t USING source AS s "
- + "ON t.id == s.id "
- + "WHEN MATCHED AND t.id = 1 THEN "
- + " UPDATE SET * "
- + "WHEN MATCHED AND t.id = 6 THEN "
- + " DELETE "
- + "WHEN NOT MATCHED AND s.id = 2 THEN "
- + " INSERT *",
- tableName);
-
- ImmutableList<Object[]> expectedRows =
- ImmutableList.of(
- row(1, "emp-id-1"), // updated
- row(2, "emp-id-2") // new
- );
- assertEquals(
- "Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName));
- }
}
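
The test removed above is not dropped: the next diff adds it to the shared TestMerge base class, so it now runs under every merge mode (including merge-on-read) rather than copy-on-write only. Outside the test harness, a rough repro sketch of the scenario could look like this; it assumes a SparkSession `spark` with the Iceberg Spark extensions enabled, and the catalog/table names are placeholders:

    // A format-version 2 table is needed for merge-on-read row-level deletes.
    spark.sql(
      """CREATE TABLE demo.db.t (id INT NOT NULL, dep STRING) USING iceberg
        |TBLPROPERTIES ('write.merge.mode' = 'merge-on-read', 'format-version' = '2')""".stripMargin)
    spark.sql("INSERT INTO demo.db.t VALUES (1, 'emp-id-one'), (6, 'emp-id-6')")

    spark.sql(
      """CREATE OR REPLACE TEMPORARY VIEW source AS
        |SELECT * FROM VALUES (2, 'emp-id-2'), (1, 'emp-id-1'), (6, 'emp-id-6') AS s(id, dep)""".stripMargin)

    // Before this fix, the merged plan could report `id` as nullable (DELETE rows
    // leave data columns null), so writing the delta to the NOT NULL column failed.
    spark.sql(
      """MERGE INTO demo.db.t AS t USING source AS s ON t.id = s.id
        |WHEN MATCHED AND t.id = 1 THEN UPDATE SET *
        |WHEN MATCHED AND t.id = 6 THEN DELETE
        |WHEN NOT MATCHED AND s.id = 2 THEN INSERT *""".stripMargin)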
diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java
index c485dbfe2f..e7944eea74 100644
--- a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java
+++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestMerge.java
@@ -1687,6 +1687,38 @@ public abstract class TestMerge extends SparkRowLevelOperationsTestBase {
"Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName));
}
+ @Test
+ public void testMergeWithTableWithNonNullableColumn() {
+ createAndInitTable(
+ "id INT NOT NULL, dep STRING",
+ "{ \"id\": 1, \"dep\": \"emp-id-one\" }\n" + "{ \"id\": 6, \"dep\": \"emp-id-6\" }");
+
+ createOrReplaceView(
+ "source",
+ "id INT NOT NULL, dep STRING",
+ "{ \"id\": 2, \"dep\": \"emp-id-2\" }\n"
+ + "{ \"id\": 1, \"dep\": \"emp-id-1\" }\n"
+ + "{ \"id\": 6, \"dep\": \"emp-id-6\" }");
+
+ sql(
+ "MERGE INTO %s AS t USING source AS s "
+ + "ON t.id == s.id "
+ + "WHEN MATCHED AND t.id = 1 THEN "
+ + " UPDATE SET * "
+ + "WHEN MATCHED AND t.id = 6 THEN "
+ + " DELETE "
+ + "WHEN NOT MATCHED AND s.id = 2 THEN "
+ + " INSERT *",
+ tableName);
+
+ ImmutableList<Object[]> expectedRows =
+ ImmutableList.of(
+ row(1, "emp-id-1"), // updated
+ row(2, "emp-id-2")); // new
+ assertEquals(
+ "Should have expected rows", expectedRows, sql("SELECT * FROM %s ORDER BY id", tableName));
+ }
+
@Test
public void testMergeWithNonExistingColumns() {
createAndInitTable("id INT, c STRUCT<n1:INT,n2:STRUCT<dn1:INT,dn2:INT>>");