You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2018/12/14 16:25:57 UTC
[spark] branch branch-2.4 updated: [SPARK-26370][SQL] Fix
resolution of higher-order function for the same identifier.
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-2.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-2.4 by this push:
new aec68a8 [SPARK-26370][SQL] Fix resolution of higher-order function for the same identifier.
aec68a8 is described below
commit aec68a8ff18360cd2d1f2b103e6fe64d78e3d770
Author: Takuya UESHIN <ue...@databricks.com>
AuthorDate: Sat Dec 15 00:23:28 2018 +0800
[SPARK-26370][SQL] Fix resolution of higher-order function for the same identifier.
When using a higher-order function with the same variable name as the existing columns in `Filter` or something else that uses `Analyzer.resolveExpressionBottomUp` during the resolution, e.g.:
```scala
val df = Seq(
(Seq(1, 9, 8, 7), 1, 2),
(Seq(5, 9, 7), 2, 2),
(Seq.empty, 3, 2),
(null, 4, 2)
).toDF("i", "x", "d")
checkAnswer(df.filter("exists(i, x -> x % d == 0)"),
Seq(Row(Seq(1, 9, 8, 7), 1, 2)))
checkAnswer(df.select("x").filter("exists(i, x -> x % d == 0)"),
Seq(Row(1)))
```
the following exception occurs:
```
java.lang.ClassCastException: org.apache.spark.sql.catalyst.expressions.BoundReference cannot be cast to org.apache.spark.sql.catalyst.expressions.NamedExpression
at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:237)
at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
at scala.collection.TraversableLike.map(TraversableLike.scala:237)
at scala.collection.TraversableLike.map$(TraversableLike.scala:230)
at scala.collection.AbstractTraversable.map(Traversable.scala:108)
at org.apache.spark.sql.catalyst.expressions.HigherOrderFunction.$anonfun$functionsForEval$1(higherOrderFunctions.scala:147)
at scala.collection.TraversableLike.$anonfun$map$1(TraversableLike.scala:237)
at scala.collection.immutable.List.foreach(List.scala:392)
at scala.collection.TraversableLike.map(TraversableLike.scala:237)
at scala.collection.TraversableLike.map$(TraversableLike.scala:230)
at scala.collection.immutable.List.map(List.scala:298)
at org.apache.spark.sql.catalyst.expressions.HigherOrderFunction.functionsForEval(higherOrderFunctions.scala:145)
at org.apache.spark.sql.catalyst.expressions.HigherOrderFunction.functionsForEval$(higherOrderFunctions.scala:145)
at org.apache.spark.sql.catalyst.expressions.ArrayExists.functionsForEval$lzycompute(higherOrderFunctions.scala:369)
at org.apache.spark.sql.catalyst.expressions.ArrayExists.functionsForEval(higherOrderFunctions.scala:369)
at org.apache.spark.sql.catalyst.expressions.SimpleHigherOrderFunction.functionForEval(higherOrderFunctions.scala:176)
at org.apache.spark.sql.catalyst.expressions.SimpleHigherOrderFunction.functionForEval$(higherOrderFunctions.scala:176)
at org.apache.spark.sql.catalyst.expressions.ArrayExists.functionForEval(higherOrderFunctions.scala:369)
at org.apache.spark.sql.catalyst.expressions.ArrayExists.nullSafeEval(higherOrderFunctions.scala:387)
at org.apache.spark.sql.catalyst.expressions.SimpleHigherOrderFunction.eval(higherOrderFunctions.scala:190)
at org.apache.spark.sql.catalyst.expressions.SimpleHigherOrderFunction.eval$(higherOrderFunctions.scala:185)
at org.apache.spark.sql.catalyst.expressions.ArrayExists.eval(higherOrderFunctions.scala:369)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificPredicate.eval(Unknown Source)
at org.apache.spark.sql.execution.FilterExec.$anonfun$doExecute$3(basicPhysicalOperators.scala:216)
at org.apache.spark.sql.execution.FilterExec.$anonfun$doExecute$3$adapted(basicPhysicalOperators.scala:215)
...
```
because the `UnresolvedAttribute`s in `LambdaFunction` are unexpectedly resolved by the rule.
This PR modifies the code to use a placeholder, `UnresolvedNamedLambdaVariable`, to prevent this unexpected resolution.
Added a test and modified some tests.
Closes #23320 from ueshin/issues/SPARK-26370/hof_resolution.
Authored-by: Takuya UESHIN <ue...@databricks.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
(cherry picked from commit 3dda58af2b7f42beab736d856bf17b4d35c8866c)
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../catalyst/analysis/higherOrderFunctions.scala | 5 +++--
.../expressions/higherOrderFunctions.scala | 26 ++++++++++++++++++++--
.../spark/sql/catalyst/parser/AstBuilder.scala | 7 ++++--
.../analysis/ResolveLambdaVariablesSuite.scala | 10 +++++----
.../catalyst/parser/ExpressionParserSuite.scala | 6 +++--
.../apache/spark/sql/DataFrameFunctionsSuite.scala | 20 +++++++++++++++++
6 files changed, 62 insertions(+), 12 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala
index dd08190..c8c7580 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala
@@ -148,13 +148,14 @@ case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] {
val lambdaMap = l.arguments.map(v => canonicalizer(v.name) -> v).toMap
l.mapChildren(resolve(_, parentLambdaMap ++ lambdaMap))
- case u @ UnresolvedAttribute(name +: nestedFields) =>
+ case u @ UnresolvedNamedLambdaVariable(name +: nestedFields) =>
parentLambdaMap.get(canonicalizer(name)) match {
case Some(lambda) =>
nestedFields.foldLeft(lambda: Expression) { (expr, fieldName) =>
ExtractValue(expr, Literal(fieldName), conf.resolver)
}
- case None => u
+ case None =>
+ UnresolvedAttribute(u.nameParts)
}
case _ =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index 32f9753..17cd2a7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -22,13 +22,35 @@ import java.util.concurrent.atomic.AtomicReference
import scala.collection.mutable
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedException}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.array.ByteArrayMethods
/**
+ * A placeholder of lambda variables to prevent unexpected resolution of [[LambdaFunction]].
+ */
+case class UnresolvedNamedLambdaVariable(nameParts: Seq[String])
+ extends LeafExpression with NamedExpression with Unevaluable {
+
+ override def name: String =
+ nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".")
+
+ override def exprId: ExprId = throw new UnresolvedException(this, "exprId")
+ override def dataType: DataType = throw new UnresolvedException(this, "dataType")
+ override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
+ override def qualifier: Seq[String] = throw new UnresolvedException(this, "qualifier")
+ override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute")
+ override def newInstance(): NamedExpression = throw new UnresolvedException(this, "newInstance")
+ override lazy val resolved = false
+
+ override def toString: String = s"lambda '$name"
+
+ override def sql: String = name
+}
+
+/**
* A named lambda variable.
*/
case class NamedLambdaVariable(
@@ -79,7 +101,7 @@ case class LambdaFunction(
object LambdaFunction {
val identity: LambdaFunction = {
- val id = UnresolvedAttribute.quoted("id")
+ val id = UnresolvedNamedLambdaVariable(Seq("id"))
LambdaFunction(id, Seq(id))
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index c6d2105..80a4d18 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -1336,9 +1336,12 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
*/
override def visitLambda(ctx: LambdaContext): Expression = withOrigin(ctx) {
val arguments = ctx.IDENTIFIER().asScala.map { name =>
- UnresolvedAttribute.quoted(name.getText)
+ UnresolvedNamedLambdaVariable(UnresolvedAttribute.quoted(name.getText).nameParts)
}
- LambdaFunction(expression(ctx.expression), arguments)
+ val function = expression(ctx.expression).transformUp {
+ case a: UnresolvedAttribute => UnresolvedNamedLambdaVariable(a.nameParts)
+ }
+ LambdaFunction(function, arguments)
}
/**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala
index c4171c7..a5847ba 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala
@@ -49,19 +49,21 @@ class ResolveLambdaVariablesSuite extends PlanTest {
comparePlans(Analyzer.execute(plan(e1)), plan(e2))
}
+ private def lv(s: Symbol) = UnresolvedNamedLambdaVariable(Seq(s.name))
+
test("resolution - no op") {
checkExpression(key, key)
}
test("resolution - simple") {
- val in = ArrayTransform(values1, LambdaFunction('x.attr + 1, 'x.attr :: Nil))
+ val in = ArrayTransform(values1, LambdaFunction(lv('x) + 1, lv('x) :: Nil))
val out = ArrayTransform(values1, LambdaFunction(lvInt + 1, lvInt :: Nil))
checkExpression(in, out)
}
test("resolution - nested") {
val in = ArrayTransform(values2, LambdaFunction(
- ArrayTransform('x.attr, LambdaFunction('x.attr + 1, 'x.attr :: Nil)), 'x.attr :: Nil))
+ ArrayTransform(lv('x), LambdaFunction(lv('x) + 1, lv('x) :: Nil)), lv('x) :: Nil))
val out = ArrayTransform(values2, LambdaFunction(
ArrayTransform(lvArray, LambdaFunction(lvInt + 1, lvInt :: Nil)), lvArray :: Nil))
checkExpression(in, out)
@@ -75,14 +77,14 @@ class ResolveLambdaVariablesSuite extends PlanTest {
test("fail - name collisions") {
val p = plan(ArrayTransform(values1,
- LambdaFunction('x.attr + 'X.attr, 'x.attr :: 'X.attr :: Nil)))
+ LambdaFunction(lv('x) + lv('X), lv('x) :: lv('X) :: Nil)))
val msg = intercept[AnalysisException](Analyzer.execute(p)).getMessage
assert(msg.contains("arguments should not have names that are semantically the same"))
}
test("fail - lambda arguments") {
val p = plan(ArrayTransform(values1,
- LambdaFunction('x.attr + 'y.attr + 'z.attr, 'x.attr :: 'y.attr :: 'z.attr :: Nil)))
+ LambdaFunction(lv('x) + lv('y) + lv('z), lv('x) :: lv('y) :: lv('z) :: Nil)))
val msg = intercept[AnalysisException](Analyzer.execute(p)).getMessage
assert(msg.contains("does not match the number of arguments expected"))
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
index 781fc1e..1eec9e7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
@@ -246,9 +246,11 @@ class ExpressionParserSuite extends PlanTest {
intercept("foo(a x)", "extraneous input 'x'")
}
+ private def lv(s: Symbol) = UnresolvedNamedLambdaVariable(Seq(s.name))
+
test("lambda functions") {
- assertEqual("x -> x + 1", LambdaFunction('x + 1, Seq('x.attr)))
- assertEqual("(x, y) -> x + y", LambdaFunction('x + 'y, Seq('x.attr, 'y.attr)))
+ assertEqual("x -> x + 1", LambdaFunction(lv('x) + 1, Seq(lv('x))))
+ assertEqual("(x, y) -> x + y", LambdaFunction(lv('x) + lv('y), Seq(lv('x), lv('y))))
}
test("window function expressions") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index d4f9b90..99abfda 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -2486,6 +2486,26 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
}
assert(ex.getMessage.contains("Cannot use null as map key"))
}
+
+ test("SPARK-26370: Fix resolution of higher-order function for the same identifier") {
+ val df = Seq(
+ (Seq(1, 9, 8, 7), 1, 2),
+ (Seq(5, 9, 7), 2, 2),
+ (Seq.empty, 3, 2),
+ (null, 4, 2)
+ ).toDF("i", "x", "d")
+
+ checkAnswer(df.selectExpr("x", "exists(i, x -> x % d == 0)"),
+ Seq(
+ Row(1, true),
+ Row(2, false),
+ Row(3, false),
+ Row(4, null)))
+ checkAnswer(df.filter("exists(i, x -> x % d == 0)"),
+ Seq(Row(Seq(1, 9, 8, 7), 1, 2)))
+ checkAnswer(df.select("x").filter("exists(i, x -> x % d == 0)"),
+ Seq(Row(1)))
+ }
}
object DataFrameFunctionsSuite {
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org