You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by gu...@apache.org on 2021/10/07 06:57:40 UTC
[spark] branch branch-3.1 updated:
[SPARK-36874][SPARK-34634][SQL][3.1] ResolveReference.dedupRight should
copy dataset_id tag to avoid ambiguous self join
This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.1 by this push:
new 33b30aa [SPARK-36874][SPARK-34634][SQL][3.1] ResolveReference.dedupRight should copy dataset_id tag to avoid ambiguous self join
33b30aa is described below
commit 33b30aae1196a95534e1bfd2f191f2fc5318e975
Author: Kousuke Saruta <sa...@oss.nttdata.com>
AuthorDate: Thu Oct 7 15:56:52 2021 +0900
[SPARK-36874][SPARK-34634][SQL][3.1] ResolveReference.dedupRight should copy dataset_id tag to avoid ambiguous self join
### What changes were proposed in this pull request?
This PR mainly backports the change from SPARK-36874 (#34172), and partially backports SPARK-34634 (#31752) so that ambiguous self joins are also handled for `ScriptTransformation`.
This PR fixes an issue where an ambiguous self join is not detected if the left and right DataFrames are swapped.
This is an example.
```
val df1 = Seq((1, 2, "A1"),(2, 1, "A2")).toDF("key1", "key2", "value")
val df2 = df1.filter($"value" === "A2")
df1.join(df2, df1("key1") === df2("key2")) // Ambiguous self join is detected and AnalysisException is thrown.
df2.join(df1, df1("key1") === df2("key2")) // Ambiguous self join is not detected.
```
The root cause seems to be that the inner function `collectConflictPlans` in `ResolveReference.dedupRight` doesn't copy the `dataset_id` tag when it copies a `LogicalPlan`.
### Why are the changes needed?
Bug fix.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
New tests.
Closes #34205 from sarutak/backport-SPARK-36874.
Authored-by: Kousuke Saruta <sa...@oss.nttdata.com>
Signed-off-by: Hyukjin Kwon <gu...@apache.org>
---
.../spark/sql/catalyst/analysis/Analyzer.scala | 44 ++++++--
.../apache/spark/sql/DataFrameSelfJoinSuite.scala | 122 ++++++++++++++++++++-
2 files changed, 153 insertions(+), 13 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 62bfd53..5578838 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1379,13 +1379,16 @@ class Analyzer(override val catalogManager: CatalogManager)
case oldVersion: SerializeFromObject
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
- Seq((oldVersion, oldVersion.copy(
- serializer = oldVersion.serializer.map(_.newInstance()))))
+ val newVersion = oldVersion.copy(serializer = oldVersion.serializer.map(_.newInstance()))
+ newVersion.copyTagsFrom(oldVersion)
+ Seq((oldVersion, newVersion))
// Handle projects that create conflicting aliases.
case oldVersion @ Project(projectList, _)
if findAliases(projectList).intersect(conflictingAttributes).nonEmpty =>
- Seq((oldVersion, oldVersion.copy(projectList = newAliases(projectList))))
+ val newVersion = oldVersion.copy(projectList = newAliases(projectList))
+ newVersion.copyTagsFrom(oldVersion)
+ Seq((oldVersion, newVersion))
// We don't need to search child plan recursively if the projectList of a Project
// is only composed of Alias and doesn't contain any conflicting attributes.
@@ -1397,8 +1400,9 @@ class Analyzer(override val catalogManager: CatalogManager)
case oldVersion @ Aggregate(_, aggregateExpressions, _)
if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
- Seq((oldVersion, oldVersion.copy(
- aggregateExpressions = newAliases(aggregateExpressions))))
+ val newVersion = oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions))
+ newVersion.copyTagsFrom(oldVersion)
+ Seq((oldVersion, newVersion))
// We don't search the child plan recursively for the same reason as the above Project.
case _ @ Aggregate(_, aggregateExpressions, _)
@@ -1407,20 +1411,28 @@ class Analyzer(override val catalogManager: CatalogManager)
case oldVersion @ FlatMapGroupsInPandas(_, _, output, _)
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
- Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
+ val newVersion = oldVersion.copy(output = output.map(_.newInstance()))
+ newVersion.copyTagsFrom(oldVersion)
+ Seq((oldVersion, newVersion))
case oldVersion @ FlatMapCoGroupsInPandas(_, _, _, output, _, _)
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
- Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
+ val newVersion = oldVersion.copy(output = output.map(_.newInstance()))
+ newVersion.copyTagsFrom(oldVersion)
+ Seq((oldVersion, newVersion))
case oldVersion @ MapInPandas(_, output, _)
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
- Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
+ val newVersion = oldVersion.copy(output = output.map(_.newInstance()))
+ newVersion.copyTagsFrom(oldVersion)
+ Seq((oldVersion, newVersion))
case oldVersion: Generate
if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
val newOutput = oldVersion.generatorOutput.map(_.newInstance())
- Seq((oldVersion, oldVersion.copy(generatorOutput = newOutput)))
+ val newVersion = oldVersion.copy(generatorOutput = newOutput)
+ newVersion.copyTagsFrom(oldVersion)
+ Seq((oldVersion, newVersion))
case oldVersion: Expand
if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
@@ -1432,12 +1444,22 @@ class Analyzer(override val catalogManager: CatalogManager)
attr
}
}
- Seq((oldVersion, oldVersion.copy(output = newOutput)))
+ val newVersion = oldVersion.copy(output = newOutput)
+ newVersion.copyTagsFrom(oldVersion)
+ Seq((oldVersion, newVersion))
case oldVersion @ Window(windowExpressions, _, _, child)
if AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
.nonEmpty =>
- Seq((oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions))))
+ val newVersion = oldVersion.copy(windowExpressions = newAliases(windowExpressions))
+ newVersion.copyTagsFrom(oldVersion)
+ Seq((oldVersion, newVersion))
+
+ case oldVersion @ ScriptTransformation(_, _, output, _, _)
+ if AttributeSet(output).intersect(conflictingAttributes).nonEmpty =>
+ val newVersion = oldVersion.copy(output = output.map(_.newInstance()))
+ newVersion.copyTagsFrom(oldVersion)
+ Seq((oldVersion, newVersion))
case _ => plan.children.flatMap(collectConflictPlans)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
index 062404f..9994981 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
@@ -17,12 +17,15 @@
package org.apache.spark.sql
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.api.python.PythonEvalType
+import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, AttributeReference, PythonUDF, SortOrder}
+import org.apache.spark.sql.catalyst.plans.logical.{Expand, Generate, ScriptInputOutputSchema, ScriptTransformation, Window => WindowPlan}
import org.apache.spark.sql.expressions.Window
-import org.apache.spark.sql.functions.{count, sum}
+import org.apache.spark.sql.functions.{count, explode, sum}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SQLTestData.TestData
+import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType}
class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
import testImplicits._
@@ -344,4 +347,119 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
assertAmbiguousSelfJoin(df1.join(df2).join(df5).join(df4).select(df2("b")))
}
}
+
+ test("SPARK-36874: DeduplicateRelations should copy dataset_id tag " +
+ "to avoid ambiguous self join") {
+ // Test for Project
+ val df1 = Seq((1, 2, "A1"), (2, 1, "A2")).toDF("key1", "key2", "value")
+ val df2 = df1.filter($"value" === "A2")
+ assertAmbiguousSelfJoin(df1.join(df2, df1("key1") === df2("key2")))
+ assertAmbiguousSelfJoin(df2.join(df1, df1("key1") === df2("key2")))
+
+ // Test for SerializeFromObject
+ val df3 = spark.sparkContext.parallelize(1 to 10).map(x => (x, x)).toDF
+ val df4 = df3.filter($"_1" <=> 0)
+ assertAmbiguousSelfJoin(df3.join(df4, df3("_1") === df4("_2")))
+ assertAmbiguousSelfJoin(df4.join(df3, df3("_1") === df4("_2")))
+
+ // Test For Aggregate
+ val df5 = df1.groupBy($"key1").agg(count($"value") as "count")
+ val df6 = df5.filter($"key1" > 0)
+ assertAmbiguousSelfJoin(df5.join(df6, df5("key1") === df6("count")))
+ assertAmbiguousSelfJoin(df6.join(df5, df5("key1") === df6("count")))
+
+ // Test for MapInPandas
+ val mapInPandasUDF = PythonUDF("mapInPandasUDF", null,
+ StructType(Seq(StructField("x", LongType), StructField("y", LongType))),
+ Seq.empty,
+ PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
+ true)
+ val df7 = df1.mapInPandas(mapInPandasUDF)
+ val df8 = df7.filter($"x" > 0)
+ assertAmbiguousSelfJoin(df7.join(df8, df7("x") === df8("y")))
+ assertAmbiguousSelfJoin(df8.join(df7, df7("x") === df8("y")))
+
+ // Test for FlatMapGroupsInPandas
+ val flatMapGroupsInPandasUDF = PythonUDF("flagMapGroupsInPandasUDF", null,
+ StructType(Seq(StructField("x", LongType), StructField("y", LongType))),
+ Seq.empty,
+ PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+ true)
+ val df9 = df1.groupBy($"key1").flatMapGroupsInPandas(flatMapGroupsInPandasUDF)
+ val df10 = df9.filter($"x" > 0)
+ assertAmbiguousSelfJoin(df9.join(df10, df9("x") === df10("y")))
+ assertAmbiguousSelfJoin(df10.join(df9, df9("x") === df10("y")))
+
+ // Test for FlatMapCoGroupsInPandas
+ val flatMapCoGroupsInPandasUDF = PythonUDF("flagMapCoGroupsInPandasUDF", null,
+ StructType(Seq(StructField("x", LongType), StructField("y", LongType))),
+ Seq.empty,
+ PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
+ true)
+ val df11 = df1.groupBy($"key1").flatMapCoGroupsInPandas(
+ df1.groupBy($"key2"), flatMapCoGroupsInPandasUDF)
+ val df12 = df11.filter($"x" > 0)
+ assertAmbiguousSelfJoin(df11.join(df12, df11("x") === df12("y")))
+ assertAmbiguousSelfJoin(df12.join(df11, df11("x") === df12("y")))
+
+ // Test for Generate
+ // Ensure that the root of the plan is Generate
+ val df13 = Seq((1, Seq(1, 2, 3))).toDF("a", "intList").select($"a", explode($"intList"))
+ .queryExecution.optimizedPlan.find(_.isInstanceOf[Generate]).get.toDF
+ val df14 = df13.filter($"a" > 0)
+ assertAmbiguousSelfJoin(df13.join(df14, df13("a") === df14("col")))
+ assertAmbiguousSelfJoin(df14.join(df13, df13("a") === df14("col")))
+
+ // Test for Expand
+ // Ensure that the root of the plan is Expand
+ val df15 =
+ Expand(
+ Seq(Seq($"key1".expr, $"key2".expr)),
+ Seq(
+ AttributeReference("x", IntegerType)(),
+ AttributeReference("y", IntegerType)()),
+ df1.queryExecution.logical).toDF
+ val df16 = df15.filter($"x" > 0)
+ assertAmbiguousSelfJoin(df15.join(df16, df15("x") === df16("y")))
+ assertAmbiguousSelfJoin(df16.join(df15, df15("x") === df16("y")))
+
+ // Test for Window
+ val dfWithTS = spark.sql("SELECT timestamp'2021-10-15 01:52:00' time, 1 a, 2 b")
+ // Ensure that the root of the plan is Window
+ val df17 = WindowPlan(
+ Seq(Alias(dfWithTS("time").expr, "ts")()),
+ Seq(dfWithTS("a").expr),
+ Seq(SortOrder(dfWithTS("a").expr, Ascending)),
+ dfWithTS.queryExecution.logical).toDF
+ val df18 = df17.filter($"a" > 0)
+ assertAmbiguousSelfJoin(df17.join(df18, df17("a") === df18("b")))
+ assertAmbiguousSelfJoin(df18.join(df17, df17("a") === df18("b")))
+
+ // Test for ScriptTransformation
+ val ioSchema =
+ ScriptInputOutputSchema(
+ Seq(("TOK_TABLEROWFORMATFIELD", ","),
+ ("TOK_TABLEROWFORMATCOLLITEMS", "#"),
+ ("TOK_TABLEROWFORMATMAPKEYS", "@"),
+ ("TOK_TABLEROWFORMATNULL", "null"),
+ ("TOK_TABLEROWFORMATLINES", "\n")),
+ Seq(("TOK_TABLEROWFORMATFIELD", ","),
+ ("TOK_TABLEROWFORMATCOLLITEMS", "#"),
+ ("TOK_TABLEROWFORMATMAPKEYS", "@"),
+ ("TOK_TABLEROWFORMATNULL", "null"),
+ ("TOK_TABLEROWFORMATLINES", "\n")), None, None,
+ List.empty, List.empty, None, None, false)
+ // Ensure that the root of the plan is ScriptTransformation
+ val df19 = ScriptTransformation(
+ Seq($"key1".expr, $"key2".expr),
+ "cat",
+ Seq(
+ AttributeReference("x", IntegerType)(),
+ AttributeReference("y", IntegerType)()),
+ df1.queryExecution.logical,
+ ioSchema).toDF
+ val df20 = df19.filter($"x" > 0)
+ assertAmbiguousSelfJoin(df19.join(df20, df19("x") === df20("y")))
+ assertAmbiguousSelfJoin(df20.join(df19, df19("x") === df20("y")))
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org