You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2018/09/27 07:13:26 UTC

spark git commit: [SPARK-25314][SQL] Fix Python UDF accessing attributes from both side of join in join conditions

Repository: spark
Updated Branches:
  refs/heads/master d03e0af80 -> 2a8cbfddb


[SPARK-25314][SQL] Fix Python UDF accessing attributes from both side of join in join conditions

## What changes were proposed in this pull request?

Thanks to bahchis for reporting this. This is a follow-up to #16581: this PR fixes the case of a Python UDF that accesses attributes from both sides of a join in the join condition.

## How was this patch tested?

Added regression tests in PySpark and `BatchEvalPythonExecSuite`.

Closes #22326 from xuanyuanking/SPARK-25314.

Authored-by: Yuanjian Li <xy...@gmail.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2a8cbfdd
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2a8cbfdd
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2a8cbfdd

Branch: refs/heads/master
Commit: 2a8cbfddba2a59d144b32910c68c22d0199093fe
Parents: d03e0af
Author: Yuanjian Li <xy...@gmail.com>
Authored: Thu Sep 27 15:13:18 2018 +0800
Committer: Wenchen Fan <we...@databricks.com>
Committed: Thu Sep 27 15:13:18 2018 +0800

----------------------------------------------------------------------
 python/pyspark/sql/tests.py                     | 64 ++++++++++++++++++++
 .../sql/catalyst/optimizer/Optimizer.scala      |  8 ++-
 .../spark/sql/catalyst/optimizer/joins.scala    | 49 +++++++++++++++
 3 files changed, 119 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/2a8cbfdd/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 64a7ceb..b88a655 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -596,6 +596,70 @@ class SQLTests(ReusedSQLTestCase):
         df = left.crossJoin(right).filter(f("a", "b"))
         self.assertEqual(df.collect(), [Row(a=1, b=1)])
 
+    def test_udf_in_join_condition(self):
+        # Regression test for SPARK-25314: a UDF-only join condition turns into a cross join.
+        from pyspark.sql.functions import udf
+        left = self.spark.createDataFrame([Row(a=1)])
+        right = self.spark.createDataFrame([Row(b=1)])
+        f = udf(lambda a, b: a == b, BooleanType())
+        df = left.join(right, f("a", "b"))
+        with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'):
+            df.collect()
+        with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
+            self.assertEqual(df.collect(), [Row(a=1, b=1)])
+
+    def test_udf_in_left_semi_join_condition(self):
+        # Regression test for SPARK-25314: a UDF-only condition in a left semi join.
+        from pyspark.sql.functions import udf
+        left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
+        right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1)])
+        f = udf(lambda a, b: a == b, BooleanType())
+        df = left.join(right, f("a", "b"), "leftsemi")
+        with self.assertRaisesRegexp(AnalysisException, 'Detected implicit cartesian product'):
+            df.collect()
+        with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
+            self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1)])
+
+    def test_udf_and_common_filter_in_join_condition(self):
+        # Regression test for SPARK-25314.
+        # Covers an inner join condition mixing a UDF and an equality predicate.
+        from pyspark.sql.functions import udf
+        left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
+        right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
+        f = udf(lambda a, b: a == b, BooleanType())
+        df = left.join(right, [f("a", "b"), left.a1 == right.b1])
+        # spark.sql.crossJoin.enabled is not needed: the UDF is not the only join condition.
+        self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])
+
+    def test_udf_and_common_filter_in_left_semi_join_condition(self):
+        # Regression test for SPARK-25314.
+        # Covers a left semi join condition mixing a UDF and an equality predicate.
+        from pyspark.sql.functions import udf
+        left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
+        right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
+        f = udf(lambda a, b: a == b, BooleanType())
+        df = left.join(right, [f("a", "b"), left.a1 == right.b1], "left_semi")
+        # spark.sql.crossJoin.enabled is not needed: the UDF is not the only join condition.
+        self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1)])
+
+    def test_udf_not_supported_in_join_condition(self):
+        # Regression test for SPARK-25314.
+        # UDF join conditions are rejected for join types other than inner and left semi.
+        from pyspark.sql.functions import udf
+        left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
+        right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
+        f = udf(lambda a, b: a == b, BooleanType())
+
+        def runWithJoinType(join_type, type_string):  # type_string: name in the error message
+            with self.assertRaisesRegexp(
+                    AnalysisException,
+                    'Using PythonUDF.*%s is not supported.' % type_string):
+                left.join(right, [f("a", "b"), left.a1 == right.b1], join_type).collect()
+        runWithJoinType("full", "FullOuter")
+        runWithJoinType("left", "LeftOuter")
+        runWithJoinType("right", "RightOuter")
+        runWithJoinType("leftanti", "LeftAnti")
+
     def test_udf_without_arguments(self):
         self.spark.catalog.registerFunction("foo", lambda: "bar")
         [row] = self.spark.sql("SELECT foo()").collect()

http://git-wip-us.apache.org/repos/asf/spark/blob/2a8cbfdd/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 07a653f..da8009d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -165,7 +165,10 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
     Batch("LocalRelation", fixedPoint,
       ConvertToLocalRelation,
       PropagateEmptyRelation) :+
-    // The following batch should be executed after batch "Join Reorder" and "LocalRelation".
+    Batch("Extract PythonUDF From JoinCondition", Once,
+      PullOutPythonUDFInJoinCondition) :+
+    // The following batch should be executed after batches "Join Reorder", "LocalRelation" and
+    // "Extract PythonUDF From JoinCondition".
     Batch("Check Cartesian Products", Once,
       CheckCartesianProducts) :+
     Batch("RewriteSubquery", Once,
@@ -202,7 +205,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
       ReplaceDistinctWithAggregate.ruleName ::
       PullupCorrelatedPredicates.ruleName ::
       RewriteCorrelatedScalarSubquery.ruleName ::
-      RewritePredicateSubquery.ruleName :: Nil
+      RewritePredicateSubquery.ruleName ::
+      PullOutPythonUDFInJoinCondition.ruleName :: Nil
 
   /**
    * Optimize all the subqueries inside expression.

http://git-wip-us.apache.org/repos/asf/spark/blob/2a8cbfdd/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
index edbeaf2..7149ede 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import scala.annotation.tailrec
 
+import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.planning.ExtractFiltersAndInnerJoins
 import org.apache.spark.sql.catalyst.plans._
@@ -152,3 +153,51 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper {
       if (j.joinType == newJoinType) f else Filter(condition, j.copy(joinType = newJoinType))
   }
 }
+
+/**
+ * A PythonUDF in a join condition cannot be evaluated by the join itself. This rule detects
+ * join conditions containing PythonUDFs and pulls those conjuncts out into a Filter above the
+ * join. A UDF referencing only one side is normally pushed down by the predicate push-down
+ * rules; if it is not (e.g. the user disabled those rules), this rule pulls it out as well.
+ */
+object PullOutPythonUDFInJoinCondition extends Rule[LogicalPlan] with PredicateHelper {
+  def hasPythonUDF(expression: Expression): Boolean = {  // true if the tree has a PythonUDF
+    expression.collectFirst { case udf: PythonUDF => udf }.isDefined
+  }
+
+  override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+    case j @ Join(_, _, joinType, Some(cond))
+        if hasPythonUDF(cond) =>
+      if (!joinType.isInstanceOf[InnerLike] && joinType != LeftSemi) {
+        // Only InnerLike and LeftSemi joins are supported: for any other join type, running
+        // the join condition as a filter after the join would break SQL semantics. If the plan
+        // were passed through unchanged, it would still fail later with a PythonUDF
+        // RuntimeException (`requires attributes from more than one child`), so throw here
+        // first with a more readable message.
+        throw new AnalysisException("Using PythonUDF in join condition of join type" +
+          s" $joinType is not supported.")
+      }
+      // Split the condition into conjuncts; those containing a PythonUDF are moved out of the
+      // join and evaluated in a Filter above it; the rest (if any) form the new join condition.
+      val (udf, rest) =
+        splitConjunctivePredicates(cond).partition(hasPythonUDF)
+      val newCondition = if (rest.isEmpty) {
+        logWarning(s"The join condition:$cond of the join plan contains PythonUDF only," +
+          " it will be moved out and the join plan will be turned to cross join.")
+        None
+      } else {
+        Some(rest.reduceLeft(And))
+      }
+      val newJoin = j.copy(condition = newCondition)
+      joinType match {
+        case _: InnerLike => Filter(udf.reduceLeft(And), newJoin)
+        case LeftSemi =>  // NOTE(review): may repeat a left row per matching right row
+          Project(  // keep only the left-side columns, as a semi join would
+            j.left.output.map(_.toAttribute),
+            Filter(udf.reduceLeft(And), newJoin.copy(joinType = Inner)))
+        case _ =>  // unreachable: the guard above already rejects other join types
+          throw new AnalysisException("Using PythonUDF in join condition of join type" +
+            s" $joinType is not supported.")
+      }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org