Posted to commits@spark.apache.org by hv...@apache.org on 2017/01/21 00:11:44 UTC

spark git commit: [SPARK-18589][SQL] Fix Python UDF accessing attributes from both sides of a join

Repository: spark
Updated Branches:
  refs/heads/master 552e5f088 -> 9b7a03f15


[SPARK-18589][SQL] Fix Python UDF accessing attributes from both sides of a join

## What changes were proposed in this pull request?

A PythonUDF is Unevaluable, so it cannot be used inside a join condition. Currently the optimizer pushes a PythonUDF that accesses attributes from both sides of a join into the join condition, and the query then fails to plan.

This PR fixes the issue by checking whether an expression can be evaluated within a join before pushing it into the join condition.
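
For context, a minimal PySpark sketch of the failing pattern, mirroring the regression test added below (before this fix, planning such a query failed):

    from pyspark.sql import Row, SparkSession
    from pyspark.sql.functions import udf
    from pyspark.sql.types import BooleanType

    spark = SparkSession.builder.getOrCreate()
    left = spark.createDataFrame([Row(a=1)])
    right = spark.createDataFrame([Row(b=1)])

    # The UDF needs columns from both sides of the join, so it can only
    # run after the join, never as part of the join condition itself.
    eq = udf(lambda a, b: a == b, BooleanType())
    df = left.crossJoin(right).filter(eq("a", "b"))
    print(df.collect())  # expected: [Row(a=1, b=1)]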

## How was this patch tested?

Add a regression test.

Author: Davies Liu <da...@databricks.com>

Closes #16581 from davies/pyudf_join.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9b7a03f1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9b7a03f1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9b7a03f1

Branch: refs/heads/master
Commit: 9b7a03f15ac45e5f7dcf118d1e7ce1556339aa46
Parents: 552e5f0
Author: Davies Liu <da...@databricks.com>
Authored: Fri Jan 20 16:11:40 2017 -0800
Committer: Herman van Hovell <hv...@databricks.com>
Committed: Fri Jan 20 16:11:40 2017 -0800

----------------------------------------------------------------------
 python/pyspark/sql/tests.py                           |  9 +++++++++
 .../spark/sql/catalyst/expressions/predicates.scala   | 13 ++++++++++++-
 .../spark/sql/catalyst/optimizer/Optimizer.scala      |  2 +-
 .../apache/spark/sql/catalyst/optimizer/joins.scala   |  5 ++---
 .../execution/python/BatchEvalPythonExecSuite.scala   | 14 ++++++--------
 5 files changed, 30 insertions(+), 13 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9b7a03f1/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 73a5df6..4bfe6e9 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -342,6 +342,15 @@ class SQLTests(ReusedPySparkTestCase):
         df = df.withColumn('b', udf(lambda x: 'x')(df.a))
         self.assertEqual(df.filter('b = "x"').collect(), [Row(a=1, b='x')])
 
+    def test_udf_in_filter_on_top_of_join(self):
+        # regression test for SPARK-18589
+        from pyspark.sql.functions import udf
+        left = self.spark.createDataFrame([Row(a=1)])
+        right = self.spark.createDataFrame([Row(b=1)])
+        f = udf(lambda a, b: a == b, BooleanType())
+        df = left.crossJoin(right).filter(f("a", "b"))
+        self.assertEqual(df.collect(), [Row(a=1, b=1)])
+
     def test_udf_without_arguments(self):
         self.spark.catalog.registerFunction("foo", lambda: "bar")
         [row] = self.spark.sql("SELECT foo()").collect()

http://git-wip-us.apache.org/repos/asf/spark/blob/9b7a03f1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 3fcbb05..ac56ff1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCo
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
-import org.apache.spark.util.Utils
 
 
 object InterpretedPredicate {
@@ -86,6 +85,18 @@ trait PredicateHelper {
    */
   protected def canEvaluate(expr: Expression, plan: LogicalPlan): Boolean =
     expr.references.subsetOf(plan.outputSet)
+
+  /**
+   * Returns true iff `expr` could be evaluated as a condition within join.
+   */
+  protected def canEvaluateWithinJoin(expr: Expression): Boolean = expr match {
+    case e: SubqueryExpression =>
+      // non-correlated subquery will be replaced as literal
+      e.children.isEmpty
+    case a: AttributeReference => true
+    case e: Unevaluable => false
+    case e => e.children.forall(canEvaluateWithinJoin)
+  }
 }
 
 @ExpressionDescription(
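
A note on the new check: it walks the expression tree. An uncorrelated subquery passes (it is later replaced by a literal), a bare attribute reference passes, anything Unevaluable (such as a PythonUDF) fails, and any other expression passes only if all of its children do. A rough, purely illustrative Python analogue over a toy expression tree (these class names are hypothetical, not PySpark APIs):

    class Expr:
        def __init__(self, children=()):
            self.children = list(children)

    class SubqueryExpr(Expr): pass   # correlated iff it has children
    class AttributeRef(Expr): pass   # a column reference
    class Unevaluable(Expr): pass    # e.g. a Python UDF placeholder

    def can_evaluate_within_join(expr):
        if isinstance(expr, SubqueryExpr):
            # An uncorrelated subquery is replaced by a literal later on.
            return not expr.children
        if isinstance(expr, AttributeRef):
            return True
        if isinstance(expr, Unevaluable):
            return False
        return all(can_evaluate_within_join(c) for c in expr.children)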

http://git-wip-us.apache.org/repos/asf/spark/blob/9b7a03f1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 009c517..20b3898 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -893,7 +893,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
           val newRight = rightFilterConditions.
             reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
           val (newJoinConditions, others) =
-            commonFilterCondition.partition(e => !SubqueryExpression.hasCorrelatedSubquery(e))
+            commonFilterCondition.partition(canEvaluateWithinJoin)
           val newJoinCond = (newJoinConditions ++ joinCondition).reduceLeftOption(And)
 
           val join = Join(newLeft, newRight, joinType, newJoinCond)
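
The effect of this one-line change: PushPredicateThroughJoin previously kept only correlated subqueries out of the join condition; now it keeps out anything that cannot be evaluated within the join, so a Python UDF predicate stays in a Filter above the Join. Continuing the PySpark sketch above, one way to observe this (exact plan output varies by version and is elided here):

    # The UDF predicate should now appear above the join, evaluated via a
    # single BatchEvalPythonExec node, rather than inside the join itself.
    df.explain(True)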

http://git-wip-us.apache.org/repos/asf/spark/blob/9b7a03f1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
index 180ad2e..bfe529e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
@@ -46,8 +46,7 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
     : LogicalPlan = {
     assert(input.size >= 2)
     if (input.size == 2) {
-      val (joinConditions, others) = conditions.partition(
-        e => !SubqueryExpression.hasCorrelatedSubquery(e))
+      val (joinConditions, others) = conditions.partition(canEvaluateWithinJoin)
       val ((left, leftJoinType), (right, rightJoinType)) = (input(0), input(1))
       val innerJoinType = (leftJoinType, rightJoinType) match {
         case (Inner, Inner) => Inner
@@ -75,7 +74,7 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
 
       val joinedRefs = left.outputSet ++ right.outputSet
       val (joinConditions, others) = conditions.partition(
-        e => e.references.subsetOf(joinedRefs) && !SubqueryExpression.hasCorrelatedSubquery(e))
+        e => e.references.subsetOf(joinedRefs) && canEvaluateWithinJoin(e))
       val joined = Join(left, right, innerJoinType, joinConditions.reduceLeftOption(And))
 
       // should not have reference to same logical plan

http://git-wip-us.apache.org/repos/asf/spark/blob/9b7a03f1/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
index 81bea2f..2a3d1cf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
@@ -21,7 +21,7 @@ import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
 
 import org.apache.spark.api.python.PythonFunction
-import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, GreaterThan, In}
+import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, EqualTo, Expression, GreaterThan, In}
 import org.apache.spark.sql.execution.{FilterExec, InputAdapter, SparkPlanTest, WholeStageCodegenExec}
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types.BooleanType
@@ -86,13 +86,11 @@ class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext {
   test("Python UDF refers to the attributes from more than one child") {
     val df = Seq(("Hello", 4)).toDF("a", "b")
     val df2 = Seq(("Hello", 4)).toDF("c", "d")
-    val joinDF = df.join(df2).where("dummyPythonUDF(a, c) == dummyPythonUDF(d, c)")
-
-    val e = intercept[RuntimeException] {
-      joinDF.queryExecution.executedPlan
-    }.getMessage
-    assert(Seq("Invalid PythonUDF dummyUDF", "requires attributes from more than one child")
-      .forall(e.contains))
+    val joinDF = df.crossJoin(df2).where("dummyPythonUDF(a, c) == dummyPythonUDF(d, c)")
+    val qualifiedPlanNodes = joinDF.queryExecution.executedPlan.collect {
+      case b: BatchEvalPythonExec => b
+    }
+    assert(qualifiedPlanNodes.size == 1)
   }
 }
 

