You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2021/10/07 06:57:40 UTC

[spark] branch branch-3.1 updated: [SPARK-36874][SPARK-34634][SQL][3.1] ResolveReference.dedupRight should copy dataset_id tag to avoid ambiguous self join

This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.1 by this push:
     new 33b30aa  [SPARK-36874][SPARK-34634][SQL][3.1] ResolveReference.dedupRight should copy dataset_id tag to avoid ambiguous self join
33b30aa is described below

commit 33b30aae1196a95534e1bfd2f191f2fc5318e975
Author: Kousuke Saruta <sa...@oss.nttdata.com>
AuthorDate: Thu Oct 7 15:56:52 2021 +0900

    [SPARK-36874][SPARK-34634][SQL][3.1] ResolveReference.dedupRight should copy dataset_id tag to avoid ambiguous self join
    
    ### What changes were proposed in this pull request?
    
    This PR mainly backports the change from SPARK-36874 (#34172), and partially backports SPARK-34634 (#31752) so that ambiguous self joins are also handled for `ScriptTransformation`.
    This PR fixes an issue where an ambiguous self join is not detected if the left and right DataFrames are swapped.
    This is an example.
    ```
    val df1 = Seq((1, 2, "A1"),(2, 1, "A2")).toDF("key1", "key2", "value")
    val df2 = df1.filter($"value" === "A2")
    
    df1.join(df2, df1("key1") === df2("key2")) // Ambiguous self join is detected and AnalysisException is thrown.
    
    df2.join(df1, df1("key1") === df2("key2")) // Ambiguous self join is not detected.
    ```
    
    The root cause seems to be that the inner function `collectConflictPlans` in `ResolveReference.dedupRight` doesn't copy the `dataset_id` tag when it copies a `LogicalPlan`.
    
    ### Why are the changes needed?
    
    Bug fix.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    New tests.
    
    Closes #34205 from sarutak/backport-SPARK-36874.
    
    Authored-by: Kousuke Saruta <sa...@oss.nttdata.com>
    Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     |  44 ++++++--
 .../apache/spark/sql/DataFrameSelfJoinSuite.scala  | 122 ++++++++++++++++++++-
 2 files changed, 153 insertions(+), 13 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 62bfd53..5578838 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1379,13 +1379,16 @@ class Analyzer(override val catalogManager: CatalogManager)
 
         case oldVersion: SerializeFromObject
             if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
-          Seq((oldVersion, oldVersion.copy(
-            serializer = oldVersion.serializer.map(_.newInstance()))))
+          val newVersion = oldVersion.copy(serializer = oldVersion.serializer.map(_.newInstance()))
+          newVersion.copyTagsFrom(oldVersion)
+          Seq((oldVersion, newVersion))
 
         // Handle projects that create conflicting aliases.
         case oldVersion @ Project(projectList, _)
             if findAliases(projectList).intersect(conflictingAttributes).nonEmpty =>
-          Seq((oldVersion, oldVersion.copy(projectList = newAliases(projectList))))
+          val newVersion = oldVersion.copy(projectList = newAliases(projectList))
+          newVersion.copyTagsFrom(oldVersion)
+          Seq((oldVersion, newVersion))
 
         // We don't need to search child plan recursively if the projectList of a Project
         // is only composed of Alias and doesn't contain any conflicting attributes.
@@ -1397,8 +1400,9 @@ class Analyzer(override val catalogManager: CatalogManager)
 
         case oldVersion @ Aggregate(_, aggregateExpressions, _)
             if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
-          Seq((oldVersion, oldVersion.copy(
-            aggregateExpressions = newAliases(aggregateExpressions))))
+          val newVersion = oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions))
+          newVersion.copyTagsFrom(oldVersion)
+          Seq((oldVersion, newVersion))
 
         // We don't search the child plan recursively for the same reason as the above Project.
         case _ @ Aggregate(_, aggregateExpressions, _)
@@ -1407,20 +1411,28 @@ class Analyzer(override val catalogManager: CatalogManager)
 
         case oldVersion @ FlatMapGroupsInPandas(_, _, output, _)
             if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
-          Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
+          val newVersion = oldVersion.copy(output = output.map(_.newInstance()))
+          newVersion.copyTagsFrom(oldVersion)
+          Seq((oldVersion, newVersion))
 
         case oldVersion @ FlatMapCoGroupsInPandas(_, _, _, output, _, _)
             if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
-          Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
+          val newVersion = oldVersion.copy(output = output.map(_.newInstance()))
+          newVersion.copyTagsFrom(oldVersion)
+          Seq((oldVersion, newVersion))
 
         case oldVersion @ MapInPandas(_, output, _)
             if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
-          Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
+          val newVersion = oldVersion.copy(output = output.map(_.newInstance()))
+          newVersion.copyTagsFrom(oldVersion)
+          Seq((oldVersion, newVersion))
 
         case oldVersion: Generate
             if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
           val newOutput = oldVersion.generatorOutput.map(_.newInstance())
-          Seq((oldVersion, oldVersion.copy(generatorOutput = newOutput)))
+          val newVersion = oldVersion.copy(generatorOutput = newOutput)
+          newVersion.copyTagsFrom(oldVersion)
+          Seq((oldVersion, newVersion))
 
         case oldVersion: Expand
             if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
@@ -1432,12 +1444,22 @@ class Analyzer(override val catalogManager: CatalogManager)
               attr
             }
           }
-          Seq((oldVersion, oldVersion.copy(output = newOutput)))
+          val newVersion = oldVersion.copy(output = newOutput)
+          newVersion.copyTagsFrom(oldVersion)
+          Seq((oldVersion, newVersion))
 
         case oldVersion @ Window(windowExpressions, _, _, child)
             if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
             .nonEmpty =>
-          Seq((oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions))))
+          val newVersion = oldVersion.copy(windowExpressions = newAliases(windowExpressions))
+          newVersion.copyTagsFrom(oldVersion)
+          Seq((oldVersion, newVersion))
+
+        case oldVersion @ ScriptTransformation(_, _, output, _, _)
+          if AttributeSet(output).intersect(conflictingAttributes).nonEmpty =>
+          val newVersion = oldVersion.copy(output = output.map(_.newInstance()))
+          newVersion.copyTagsFrom(oldVersion)
+          Seq((oldVersion, newVersion))
 
         case _ => plan.children.flatMap(collectConflictPlans)
       }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
index 062404f..9994981 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
@@ -17,12 +17,15 @@
 
 package org.apache.spark.sql
 
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.api.python.PythonEvalType
+import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, AttributeReference, PythonUDF, SortOrder}
+import org.apache.spark.sql.catalyst.plans.logical.{Expand, Generate, ScriptInputOutputSchema, ScriptTransformation, Window => WindowPlan}
 import org.apache.spark.sql.expressions.Window
-import org.apache.spark.sql.functions.{count, sum}
+import org.apache.spark.sql.functions.{count, explode, sum}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.test.SQLTestData.TestData
+import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType}
 
 class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
   import testImplicits._
@@ -344,4 +347,119 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
       assertAmbiguousSelfJoin(df1.join(df2).join(df5).join(df4).select(df2("b")))
     }
   }
+
+  test("SPARK-36874: DeduplicateRelations should copy dataset_id tag " +
+    "to avoid ambiguous self join") {
+    // Test for Project
+    val df1 = Seq((1, 2, "A1"), (2, 1, "A2")).toDF("key1", "key2", "value")
+    val df2 = df1.filter($"value" === "A2")
+    assertAmbiguousSelfJoin(df1.join(df2, df1("key1") === df2("key2")))
+    assertAmbiguousSelfJoin(df2.join(df1, df1("key1") === df2("key2")))
+
+    // Test for SerializeFromObject
+    val df3 = spark.sparkContext.parallelize(1 to 10).map(x => (x, x)).toDF
+    val df4 = df3.filter($"_1" <=> 0)
+    assertAmbiguousSelfJoin(df3.join(df4, df3("_1") === df4("_2")))
+    assertAmbiguousSelfJoin(df4.join(df3, df3("_1") === df4("_2")))
+
+    // Test For Aggregate
+    val df5 = df1.groupBy($"key1").agg(count($"value") as "count")
+    val df6 = df5.filter($"key1" > 0)
+    assertAmbiguousSelfJoin(df5.join(df6, df5("key1") === df6("count")))
+    assertAmbiguousSelfJoin(df6.join(df5, df5("key1") === df6("count")))
+
+    // Test for MapInPandas
+    val mapInPandasUDF = PythonUDF("mapInPandasUDF", null,
+      StructType(Seq(StructField("x", LongType), StructField("y", LongType))),
+      Seq.empty,
+      PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
+      true)
+    val df7 = df1.mapInPandas(mapInPandasUDF)
+    val df8 = df7.filter($"x" > 0)
+    assertAmbiguousSelfJoin(df7.join(df8, df7("x") === df8("y")))
+    assertAmbiguousSelfJoin(df8.join(df7, df7("x") === df8("y")))
+
+    // Test for FlatMapGroupsInPandas
+    val flatMapGroupsInPandasUDF = PythonUDF("flagMapGroupsInPandasUDF", null,
+      StructType(Seq(StructField("x", LongType), StructField("y", LongType))),
+      Seq.empty,
+      PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+      true)
+    val df9 = df1.groupBy($"key1").flatMapGroupsInPandas(flatMapGroupsInPandasUDF)
+    val df10 = df9.filter($"x" > 0)
+    assertAmbiguousSelfJoin(df9.join(df10, df9("x") === df10("y")))
+    assertAmbiguousSelfJoin(df10.join(df9, df9("x") === df10("y")))
+
+    // Test for FlatMapCoGroupsInPandas
+    val flatMapCoGroupsInPandasUDF = PythonUDF("flagMapCoGroupsInPandasUDF", null,
+      StructType(Seq(StructField("x", LongType), StructField("y", LongType))),
+      Seq.empty,
+      PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
+      true)
+    val df11 = df1.groupBy($"key1").flatMapCoGroupsInPandas(
+      df1.groupBy($"key2"), flatMapCoGroupsInPandasUDF)
+    val df12 = df11.filter($"x" > 0)
+    assertAmbiguousSelfJoin(df11.join(df12, df11("x") === df12("y")))
+    assertAmbiguousSelfJoin(df12.join(df11, df11("x") === df12("y")))
+
+    // Test for Generate
+    // Ensure that the root of the plan is Generate
+    val df13 = Seq((1, Seq(1, 2, 3))).toDF("a", "intList").select($"a", explode($"intList"))
+      .queryExecution.optimizedPlan.find(_.isInstanceOf[Generate]).get.toDF
+    val df14 = df13.filter($"a" > 0)
+    assertAmbiguousSelfJoin(df13.join(df14, df13("a") === df14("col")))
+    assertAmbiguousSelfJoin(df14.join(df13, df13("a") === df14("col")))
+
+    // Test for Expand
+    // Ensure that the root of the plan is Expand
+    val df15 =
+      Expand(
+        Seq(Seq($"key1".expr, $"key2".expr)),
+        Seq(
+          AttributeReference("x", IntegerType)(),
+          AttributeReference("y", IntegerType)()),
+        df1.queryExecution.logical).toDF
+    val df16 = df15.filter($"x" > 0)
+    assertAmbiguousSelfJoin(df15.join(df16, df15("x") === df16("y")))
+    assertAmbiguousSelfJoin(df16.join(df15, df15("x") === df16("y")))
+
+    // Test for Window
+    val dfWithTS = spark.sql("SELECT timestamp'2021-10-15 01:52:00' time, 1 a, 2 b")
+    // Ensure that the root of the plan is Window
+    val df17 = WindowPlan(
+      Seq(Alias(dfWithTS("time").expr, "ts")()),
+      Seq(dfWithTS("a").expr),
+      Seq(SortOrder(dfWithTS("a").expr, Ascending)),
+      dfWithTS.queryExecution.logical).toDF
+    val df18 = df17.filter($"a" > 0)
+    assertAmbiguousSelfJoin(df17.join(df18, df17("a") === df18("b")))
+    assertAmbiguousSelfJoin(df18.join(df17, df17("a") === df18("b")))
+
+    // Test for ScriptTransformation
+    val ioSchema =
+      ScriptInputOutputSchema(
+        Seq(("TOK_TABLEROWFORMATFIELD", ","),
+          ("TOK_TABLEROWFORMATCOLLITEMS", "#"),
+          ("TOK_TABLEROWFORMATMAPKEYS", "@"),
+          ("TOK_TABLEROWFORMATNULL", "null"),
+          ("TOK_TABLEROWFORMATLINES", "\n")),
+        Seq(("TOK_TABLEROWFORMATFIELD", ","),
+          ("TOK_TABLEROWFORMATCOLLITEMS", "#"),
+          ("TOK_TABLEROWFORMATMAPKEYS", "@"),
+          ("TOK_TABLEROWFORMATNULL", "null"),
+          ("TOK_TABLEROWFORMATLINES", "\n")), None, None,
+        List.empty, List.empty, None, None, false)
+    // Ensure that the root of the plan is ScriptTransformation
+    val df19 = ScriptTransformation(
+      Seq($"key1".expr, $"key2".expr),
+      "cat",
+      Seq(
+        AttributeReference("x", IntegerType)(),
+        AttributeReference("y", IntegerType)()),
+      df1.queryExecution.logical,
+      ioSchema).toDF
+    val df20 = df19.filter($"x" > 0)
+    assertAmbiguousSelfJoin(df19.join(df20, df19("x") === df20("y")))
+    assertAmbiguousSelfJoin(df20.join(df19, df19("x") === df20("y")))
+  }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org