You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2021/10/05 03:18:27 UTC

[spark] branch branch-3.2 updated: [SPARK-36874][SQL] DeduplicateRelations should copy dataset_id tag to avoid ambiguous self join

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new 8ffe00e  [SPARK-36874][SQL] DeduplicateRelations should copy dataset_id tag to avoid ambiguous self join
8ffe00e is described below

commit 8ffe00e74521c20ff533a13e5b9bb8d135fa37a7
Author: Kousuke Saruta <sa...@oss.nttdata.com>
AuthorDate: Tue Oct 5 11:16:55 2021 +0800

    [SPARK-36874][SQL] DeduplicateRelations should copy dataset_id tag to avoid ambiguous self join
    
    ### What changes were proposed in this pull request?
    
    This PR fixes an issue where an ambiguous self join is not detected if the left and right DataFrames are swapped.
    For example:
    ```
    val df1 = Seq((1, 2, "A1"),(2, 1, "A2")).toDF("key1", "key2", "value")
    val df2 = df1.filter($"value" === "A2")
    
    df1.join(df2, df1("key1") === df2("key2")) // Ambiguous self join is detected and AnalysisException is thrown.
    
    df2.join(df1, df1("key1") === df2("key2")) // Ambiguous self join is not detected.
    ```
    
    The root cause seems to be that the inner function `collectConflictPlans` in `DeduplicateRelations` doesn't copy the `dataset_id` tag when it copies a `LogicalPlan`.
    
    ### Why are the changes needed?
    
    Bug fix.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    New tests.
    
    Closes #34172 from sarutak/fix-deduplication-issue.
    
    Authored-by: Kousuke Saruta <sa...@oss.nttdata.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit fa1805db48ca53ece4cbbe42ebb2a9811a142ed2)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../catalyst/analysis/DeduplicateRelations.scala   |  50 +++++---
 .../apache/spark/sql/DataFrameSelfJoinSuite.scala  | 127 ++++++++++++++++++++-
 2 files changed, 161 insertions(+), 16 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
index 7b37891..b877c24 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala
@@ -177,13 +177,16 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
 
       case oldVersion: SerializeFromObject
           if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
-        Seq((oldVersion, oldVersion.copy(
-          serializer = oldVersion.serializer.map(_.newInstance()))))
+        val newVersion = oldVersion.copy(serializer = oldVersion.serializer.map(_.newInstance()))
+        newVersion.copyTagsFrom(oldVersion)
+        Seq((oldVersion, newVersion))
 
       // Handle projects that create conflicting aliases.
       case oldVersion @ Project(projectList, _)
           if findAliases(projectList).intersect(conflictingAttributes).nonEmpty =>
-        Seq((oldVersion, oldVersion.copy(projectList = newAliases(projectList))))
+        val newVersion = oldVersion.copy(projectList = newAliases(projectList))
+        newVersion.copyTagsFrom(oldVersion)
+        Seq((oldVersion, newVersion))
 
       // Handle projects that create conflicting outer references.
       case oldVersion @ Project(projectList, _)
@@ -193,7 +196,9 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
           case o @ OuterReference(a) if conflictingAttributes.contains(a) => Alias(o, a.name)()
           case other => other
         }
-        Seq((oldVersion, oldVersion.copy(projectList = aliasedProjectList)))
+        val newVersion = oldVersion.copy(projectList = aliasedProjectList)
+        newVersion.copyTagsFrom(oldVersion)
+        Seq((oldVersion, newVersion))
 
       // We don't need to search child plan recursively if the projectList of a Project
       // is only composed of Alias and doesn't contain any conflicting attributes.
@@ -205,8 +210,9 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
 
       case oldVersion @ Aggregate(_, aggregateExpressions, _)
           if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
-        Seq((oldVersion, oldVersion.copy(
-          aggregateExpressions = newAliases(aggregateExpressions))))
+        val newVersion = oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions))
+        newVersion.copyTagsFrom(oldVersion)
+        Seq((oldVersion, newVersion))
 
       // We don't search the child plan recursively for the same reason as the above Project.
       case _ @ Aggregate(_, aggregateExpressions, _)
@@ -215,24 +221,34 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
 
       case oldVersion @ FlatMapGroupsInPandas(_, _, output, _)
           if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
-        Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
+        val newVersion = oldVersion.copy(output = output.map(_.newInstance()))
+        newVersion.copyTagsFrom(oldVersion)
+        Seq((oldVersion, newVersion))
 
       case oldVersion @ FlatMapCoGroupsInPandas(_, _, _, output, _, _)
         if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
-        Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
+        val newVersion = oldVersion.copy(output = output.map(_.newInstance()))
+        newVersion.copyTagsFrom(oldVersion)
+        Seq((oldVersion, newVersion))
 
       case oldVersion @ MapInPandas(_, output, _)
         if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
-        Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
+        val newVersion = oldVersion.copy(output = output.map(_.newInstance()))
+        newVersion.copyTagsFrom(oldVersion)
+        Seq((oldVersion, newVersion))
 
       case oldVersion @ AttachDistributedSequence(sequenceAttr, _)
         if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
-        Seq((oldVersion, oldVersion.copy(sequenceAttr = sequenceAttr.newInstance())))
+        val newVersion = oldVersion.copy(sequenceAttr = sequenceAttr.newInstance())
+        newVersion.copyTagsFrom(oldVersion)
+        Seq((oldVersion, newVersion))
 
       case oldVersion: Generate
           if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
         val newOutput = oldVersion.generatorOutput.map(_.newInstance())
-        Seq((oldVersion, oldVersion.copy(generatorOutput = newOutput)))
+        val newVersion = oldVersion.copy(generatorOutput = newOutput)
+        newVersion.copyTagsFrom(oldVersion)
+        Seq((oldVersion, newVersion))
 
       case oldVersion: Expand
           if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
@@ -244,16 +260,22 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
             attr
           }
         }
-        Seq((oldVersion, oldVersion.copy(output = newOutput)))
+        val newVersion = oldVersion.copy(output = newOutput)
+        newVersion.copyTagsFrom(oldVersion)
+        Seq((oldVersion, newVersion))
 
       case oldVersion @ Window(windowExpressions, _, _, child)
           if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
           .nonEmpty =>
-        Seq((oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions))))
+        val newVersion = oldVersion.copy(windowExpressions = newAliases(windowExpressions))
+        newVersion.copyTagsFrom(oldVersion)
+        Seq((oldVersion, newVersion))
 
       case oldVersion @ ScriptTransformation(_, output, _, _)
           if AttributeSet(output).intersect(conflictingAttributes).nonEmpty =>
-        Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
+        val newVersion = oldVersion.copy(output = output.map(_.newInstance()))
+        newVersion.copyTagsFrom(oldVersion)
+        Seq((oldVersion, newVersion))
 
       case _ => plan.children.flatMap(collectConflictPlans)
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
index 062404f..a0ddabc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
@@ -17,12 +17,15 @@
 
 package org.apache.spark.sql
 
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.api.python.PythonEvalType
+import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, AttributeReference, PythonUDF, SortOrder}
+import org.apache.spark.sql.catalyst.plans.logical.{Expand, Generate, ScriptInputOutputSchema, ScriptTransformation, Window => WindowPlan}
 import org.apache.spark.sql.expressions.Window
-import org.apache.spark.sql.functions.{count, sum}
+import org.apache.spark.sql.functions.{count, explode, sum}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.test.SQLTestData.TestData
+import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType}
 
 class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
   import testImplicits._
@@ -344,4 +347,124 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
       assertAmbiguousSelfJoin(df1.join(df2).join(df5).join(df4).select(df2("b")))
     }
   }
+
+  test("SPARK-36874: DeduplicateRelations should copy dataset_id tag " +
+    "to avoid ambiguous self join") {
+    // Test for Project
+    val df1 = Seq((1, 2, "A1"), (2, 1, "A2")).toDF("key1", "key2", "value")
+    val df2 = df1.filter($"value" === "A2")
+    assertAmbiguousSelfJoin(df1.join(df2, df1("key1") === df2("key2")))
+    assertAmbiguousSelfJoin(df2.join(df1, df1("key1") === df2("key2")))
+
+    // Test for SerializeFromObject
+    val df3 = spark.sparkContext.parallelize(1 to 10).map(x => (x, x)).toDF
+    val df4 = df3.filter($"_1" <=> 0)
+    assertAmbiguousSelfJoin(df3.join(df4, df3("_1") === df4("_2")))
+    assertAmbiguousSelfJoin(df4.join(df3, df3("_1") === df4("_2")))
+
+    // Test For Aggregate
+    val df5 = df1.groupBy($"key1").agg(count($"value") as "count")
+    val df6 = df5.filter($"key1" > 0)
+    assertAmbiguousSelfJoin(df5.join(df6, df5("key1") === df6("count")))
+    assertAmbiguousSelfJoin(df6.join(df5, df5("key1") === df6("count")))
+
+    // Test for MapInPandas
+    val mapInPandasUDF = PythonUDF("mapInPandasUDF", null,
+      StructType(Seq(StructField("x", LongType), StructField("y", LongType))),
+      Seq.empty,
+      PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
+      true)
+    val df7 = df1.mapInPandas(mapInPandasUDF)
+    val df8 = df7.filter($"x" > 0)
+    assertAmbiguousSelfJoin(df7.join(df8, df7("x") === df8("y")))
+    assertAmbiguousSelfJoin(df8.join(df7, df7("x") === df8("y")))
+
+    // Test for FlatMapGroupsInPandas
+    val flatMapGroupsInPandasUDF = PythonUDF("flagMapGroupsInPandasUDF", null,
+      StructType(Seq(StructField("x", LongType), StructField("y", LongType))),
+      Seq.empty,
+      PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+      true)
+    val df9 = df1.groupBy($"key1").flatMapGroupsInPandas(flatMapGroupsInPandasUDF)
+    val df10 = df9.filter($"x" > 0)
+    assertAmbiguousSelfJoin(df9.join(df10, df9("x") === df10("y")))
+    assertAmbiguousSelfJoin(df10.join(df9, df9("x") === df10("y")))
+
+    // Test for FlatMapCoGroupsInPandas
+    val flatMapCoGroupsInPandasUDF = PythonUDF("flagMapCoGroupsInPandasUDF", null,
+      StructType(Seq(StructField("x", LongType), StructField("y", LongType))),
+      Seq.empty,
+      PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
+      true)
+    val df11 = df1.groupBy($"key1").flatMapCoGroupsInPandas(
+      df1.groupBy($"key2"), flatMapCoGroupsInPandasUDF)
+    val df12 = df11.filter($"x" > 0)
+    assertAmbiguousSelfJoin(df11.join(df12, df11("x") === df12("y")))
+    assertAmbiguousSelfJoin(df12.join(df11, df11("x") === df12("y")))
+
+    // Test for AttachDistributedSequence
+    val df13 = df1.withSequenceColumn("seq")
+    val df14 = df13.filter($"value" === "A2")
+    assertAmbiguousSelfJoin(df13.join(df14, df13("key1") === df14("key2")))
+    assertAmbiguousSelfJoin(df14.join(df13, df13("key1") === df14("key2")))
+
+    // Test for Generate
+    // Ensure that the root of the plan is Generate
+    val df15 = Seq((1, Seq(1, 2, 3))).toDF("a", "intList").select($"a", explode($"intList"))
+      .queryExecution.optimizedPlan.find(_.isInstanceOf[Generate]).get.toDF
+    val df16 = df15.filter($"a" > 0)
+    assertAmbiguousSelfJoin(df15.join(df16, df15("a") === df16("col")))
+    assertAmbiguousSelfJoin(df16.join(df15, df15("a") === df16("col")))
+
+    // Test for Expand
+    // Ensure that the root of the plan is Expand
+    val df17 =
+      Expand(
+        Seq(Seq($"key1".expr, $"key2".expr)),
+        Seq(
+          AttributeReference("x", IntegerType)(),
+          AttributeReference("y", IntegerType)()),
+        df1.queryExecution.logical).toDF
+    val df18 = df17.filter($"x" > 0)
+    assertAmbiguousSelfJoin(df17.join(df18, df17("x") === df18("y")))
+    assertAmbiguousSelfJoin(df18.join(df17, df17("x") === df18("y")))
+
+    // Test for Window
+    val dfWithTS = spark.sql("SELECT timestamp'2021-10-15 01:52:00' time, 1 a, 2 b")
+    // Ensure that the root of the plan is Window
+    val df19 = WindowPlan(
+      Seq(Alias(dfWithTS("time").expr, "ts")()),
+      Seq(dfWithTS("a").expr),
+      Seq(SortOrder(dfWithTS("a").expr, Ascending)),
+      dfWithTS.queryExecution.logical).toDF
+    val df20 = df19.filter($"a" > 0)
+    assertAmbiguousSelfJoin(df19.join(df20, df19("a") === df20("b")))
+    assertAmbiguousSelfJoin(df20.join(df19, df19("a") === df20("b")))
+
+    // Test for ScriptTransformation
+    val ioSchema =
+      ScriptInputOutputSchema(
+        Seq(("TOK_TABLEROWFORMATFIELD", ","),
+          ("TOK_TABLEROWFORMATCOLLITEMS", "#"),
+          ("TOK_TABLEROWFORMATMAPKEYS", "@"),
+          ("TOK_TABLEROWFORMATNULL", "null"),
+          ("TOK_TABLEROWFORMATLINES", "\n")),
+        Seq(("TOK_TABLEROWFORMATFIELD", ","),
+          ("TOK_TABLEROWFORMATCOLLITEMS", "#"),
+          ("TOK_TABLEROWFORMATMAPKEYS", "@"),
+          ("TOK_TABLEROWFORMATNULL", "null"),
+          ("TOK_TABLEROWFORMATLINES", "\n")), None, None,
+        List.empty, List.empty, None, None, false)
+    // Ensure that the root of the plan is ScriptTransformation
+    val df21 = ScriptTransformation(
+      "cat",
+      Seq(
+        AttributeReference("x", IntegerType)(),
+        AttributeReference("y", IntegerType)()),
+      df1.queryExecution.logical,
+      ioSchema).toDF
+    val df22 = df21.filter($"x" > 0)
+    assertAmbiguousSelfJoin(df21.join(df22, df21("x") === df22("y")))
+    assertAmbiguousSelfJoin(df22.join(df21, df21("x") === df22("y")))
+  }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org