You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2022/06/22 15:00:49 UTC

[spark] branch branch-3.3 updated: [SPARK-39340][SQL] DS v2 agg pushdown should allow dots in the name of top-level columns

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.3 by this push:
     new a7c21bb3ddd [SPARK-39340][SQL] DS v2 agg pushdown should allow dots in the name of top-level columns
a7c21bb3ddd is described below

commit a7c21bb3ddd3e2cd62018482818c84b19ed97c97
Author: Wenchen Fan <we...@databricks.com>
AuthorDate: Wed Jun 22 22:59:16 2022 +0800

    [SPARK-39340][SQL] DS v2 agg pushdown should allow dots in the name of top-level columns
    
    It turns out that I was wrong in https://github.com/apache/spark/pull/36727 . We still have the limitation (column name cannot contain dot) in master and 3.3 branches, in a very implicit way: The `V2ExpressionBuilder` has a boolean flag `nestedPredicatePushdownEnabled` whose default value is false. When it's false, it uses `PushableColumnWithoutNestedColumn` to match columns, which doesn't support dots in names.
    
    `V2ExpressionBuilder` is only used in 2 places:
    1. `PushableExpression`. This is a pattern match that is only used in v2 agg pushdown
    2. `PushablePredicate`. This is a pattern match that is used in various places, but all the caller sides set `nestedPredicatePushdownEnabled` to true.
    
    This PR removes the `nestedPredicatePushdownEnabled` flag from `V2ExpressionBuilder`, and makes it always support nested fields. `PushablePredicate` is also updated accordingly to remove the boolean flag, as it's always true.
    
    Fix a mistake to eliminate an unexpected limitation in DS v2 pushdown.
    
    No for end users. Data source developers can trigger agg pushdown more often.
    
    a new test
    
    Closes #36945 from cloud-fan/dsv2.
    
    Authored-by: Wenchen Fan <we...@databricks.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit 4567ed99a52d0274312ba78024c548f91659a12a)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../sql/catalyst/util/V2ExpressionBuilder.scala    | 25 +++++++-------
 .../datasources/v2/DataSourceV2Strategy.scala      | 38 ++++++++--------------
 .../execution/datasources/v2/PushDownUtils.scala   |  2 +-
 .../datasources/v2/DataSourceV2StrategySuite.scala |  2 +-
 .../org/apache/spark/sql/jdbc/JDBCV2Suite.scala    | 31 ++++++++++++------
 5 files changed, 49 insertions(+), 49 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
index c77a040bc64..23560ae1d09 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
@@ -17,19 +17,15 @@
 
 package org.apache.spark.sql.catalyst.util
 
-import org.apache.spark.sql.catalyst.expressions.{Abs, Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Cast, Ceil, Coalesce, Contains, Divide, EndsWith, EqualTo, Exp, Expression, Floor, In, InSet, IsNotNull, IsNull, Literal, Log, Multiply, Not, Or, Pow, Predicate, Remainder, Sqrt, StartsWith, StringPredicate, Subtract, UnaryMinus, WidthBucket}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue}
 import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate}
-import org.apache.spark.sql.execution.datasources.PushableColumn
 import org.apache.spark.sql.types.BooleanType
 
 /**
  * The builder to generate V2 expressions from catalyst expressions.
  */
-class V2ExpressionBuilder(
-  e: Expression, nestedPredicatePushdownEnabled: Boolean = false, isPredicate: Boolean = false) {
-
-  val pushableColumn = PushableColumn(nestedPredicatePushdownEnabled)
+class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
 
   def build(): Option[V2Expression] = generateExpression(e, isPredicate)
 
@@ -49,12 +45,8 @@ class V2ExpressionBuilder(
     case Literal(true, BooleanType) => Some(new AlwaysTrue())
     case Literal(false, BooleanType) => Some(new AlwaysFalse())
     case Literal(value, dataType) => Some(LiteralValue(value, dataType))
-    case col @ pushableColumn(name) =>
-      val ref = if (nestedPredicatePushdownEnabled) {
-        FieldReference(name)
-      } else {
-        FieldReference.column(name)
-      }
+    case col @ ColumnOrField(nameParts) =>
+      val ref = FieldReference(nameParts)
       if (isPredicate && col.dataType.isInstanceOf[BooleanType]) {
         Some(new V2Predicate("=", Array(ref, LiteralValue(true, BooleanType))))
       } else {
@@ -207,3 +199,12 @@ class V2ExpressionBuilder(
     case _ => None
   }
 }
+
+object ColumnOrField {
+  def unapply(e: Expression): Option[Seq[String]] = e match {
+    case a: Attribute => Some(Seq(a.name))
+    case s: GetStructField =>
+      unapply(s.child).map(_ :+ s.childSchema(s.ordinal).name)
+    case _ => None
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index 95418027187..9be9cdda9e0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -484,12 +484,9 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
 
 private[sql] object DataSourceV2Strategy {
 
-  private def translateLeafNodeFilterV2(
-      predicate: Expression,
-      supportNestedPredicatePushdown: Boolean): Option[Predicate] = {
-    val pushablePredicate = PushablePredicate(supportNestedPredicatePushdown)
+  private def translateLeafNodeFilterV2(predicate: Expression): Option[Predicate] = {
     predicate match {
-      case pushablePredicate(expr) => Some(expr)
+      case PushablePredicate(expr) => Some(expr)
       case _ => None
     }
   }
@@ -499,10 +496,8 @@ private[sql] object DataSourceV2Strategy {
    *
    * @return a `Some[Filter]` if the input [[Expression]] is convertible, otherwise a `None`.
    */
-  protected[sql] def translateFilterV2(
-      predicate: Expression,
-      supportNestedPredicatePushdown: Boolean): Option[Predicate] = {
-    translateFilterV2WithMapping(predicate, None, supportNestedPredicatePushdown)
+  protected[sql] def translateFilterV2(predicate: Expression): Option[Predicate] = {
+    translateFilterV2WithMapping(predicate, None)
   }
 
   /**
@@ -516,8 +511,7 @@ private[sql] object DataSourceV2Strategy {
    */
   protected[sql] def translateFilterV2WithMapping(
       predicate: Expression,
-      translatedFilterToExpr: Option[mutable.HashMap[Predicate, Expression]],
-      nestedPredicatePushdownEnabled: Boolean)
+      translatedFilterToExpr: Option[mutable.HashMap[Predicate, Expression]])
   : Option[Predicate] = {
     predicate match {
       case And(left, right) =>
@@ -531,26 +525,21 @@ private[sql] object DataSourceV2Strategy {
         // Pushing one leg of AND down is only safe to do at the top level.
         // You can see ParquetFilters' createFilter for more details.
         for {
-          leftFilter <- translateFilterV2WithMapping(
-            left, translatedFilterToExpr, nestedPredicatePushdownEnabled)
-          rightFilter <- translateFilterV2WithMapping(
-            right, translatedFilterToExpr, nestedPredicatePushdownEnabled)
+          leftFilter <- translateFilterV2WithMapping(left, translatedFilterToExpr)
+          rightFilter <- translateFilterV2WithMapping(right, translatedFilterToExpr)
         } yield new V2And(leftFilter, rightFilter)
 
       case Or(left, right) =>
         for {
-          leftFilter <- translateFilterV2WithMapping(
-            left, translatedFilterToExpr, nestedPredicatePushdownEnabled)
-          rightFilter <- translateFilterV2WithMapping(
-            right, translatedFilterToExpr, nestedPredicatePushdownEnabled)
+          leftFilter <- translateFilterV2WithMapping(left, translatedFilterToExpr)
+          rightFilter <- translateFilterV2WithMapping(right, translatedFilterToExpr)
         } yield new V2Or(leftFilter, rightFilter)
 
       case Not(child) =>
-        translateFilterV2WithMapping(child, translatedFilterToExpr, nestedPredicatePushdownEnabled)
-          .map(new V2Not(_))
+        translateFilterV2WithMapping(child, translatedFilterToExpr).map(new V2Not(_))
 
       case other =>
-        val filter = translateLeafNodeFilterV2(other, nestedPredicatePushdownEnabled)
+        val filter = translateLeafNodeFilterV2(other)
         if (filter.isDefined && translatedFilterToExpr.isDefined) {
           translatedFilterToExpr.get(filter.get) = predicate
         }
@@ -582,10 +571,9 @@ private[sql] object DataSourceV2Strategy {
 /**
  * Get the expression of DS V2 to represent catalyst predicate that can be pushed down.
  */
-case class PushablePredicate(nestedPredicatePushdownEnabled: Boolean) {
-
+object PushablePredicate {
   def unapply(e: Expression): Option[Predicate] =
-    new V2ExpressionBuilder(e, nestedPredicatePushdownEnabled, true).build().map { v =>
+    new V2ExpressionBuilder(e, true).build().map { v =>
       assert(v.isInstanceOf[Predicate])
       v.asInstanceOf[Predicate]
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
index 0ebfed2fe9e..6b366fbd68a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
@@ -80,7 +80,7 @@ object PushDownUtils extends PredicateHelper {
         for (filterExpr <- filters) {
           val translated =
             DataSourceV2Strategy.translateFilterV2WithMapping(
-              filterExpr, Some(translatedFilterToExpr), nestedPredicatePushdownEnabled = true)
+              filterExpr, Some(translatedFilterToExpr))
           if (translated.isEmpty) {
             untranslatableExprs += filterExpr
           } else {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala
index 6296da47cca..1a5a382afdc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala
@@ -37,7 +37,7 @@ class DataSourceV2StrategySuite extends PlanTest with SharedSparkSession {
    */
   def testTranslateFilter(catalystFilter: Expression, result: Option[Predicate]): Unit = {
     assertResult(result) {
-      DataSourceV2Strategy.translateFilterV2(catalystFilter, true)
+      DataSourceV2Strategy.translateFilterV2(catalystFilter)
     }
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 858aeaa1365..2f94f9ef31e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -81,9 +81,10 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
       conn.prepareStatement(
         "INSERT INTO \"test\".\"employee\" VALUES (6, 'jen', 12000, 1200, true)").executeUpdate()
       conn.prepareStatement(
-        "CREATE TABLE \"test\".\"dept\" (\"dept id\" INTEGER NOT NULL)").executeUpdate()
-      conn.prepareStatement("INSERT INTO \"test\".\"dept\" VALUES (1)").executeUpdate()
-      conn.prepareStatement("INSERT INTO \"test\".\"dept\" VALUES (2)").executeUpdate()
+        "CREATE TABLE \"test\".\"dept\" (\"dept id\" INTEGER NOT NULL, \"dept.id\" INTEGER)")
+        .executeUpdate()
+      conn.prepareStatement("INSERT INTO \"test\".\"dept\" VALUES (1, 1)").executeUpdate()
+      conn.prepareStatement("INSERT INTO \"test\".\"dept\" VALUES (2, 1)").executeUpdate()
 
       // scalastyle:off
       conn.prepareStatement(
@@ -117,10 +118,10 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     checkAnswer(sql("SELECT name, id FROM h2.test.people"), Seq(Row("fred", 1), Row("mary", 2)))
   }
 
-  private def checkPushedInfo(df: DataFrame, expectedPlanFragment: String): Unit = {
+  private def checkPushedInfo(df: DataFrame, expectedPlanFragment: String*): Unit = {
     df.queryExecution.optimizedPlan.collect {
       case _: DataSourceV2ScanRelation =>
-        checkKeywordsExistsInExplain(df, expectedPlanFragment)
+        checkKeywordsExistsInExplain(df, expectedPlanFragment: _*)
     }
   }
 
@@ -1177,11 +1178,21 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
   }
 
   test("column name with composite field") {
-    checkAnswer(sql("SELECT `dept id` FROM h2.test.dept"), Seq(Row(1), Row(2)))
-    val df = sql("SELECT COUNT(`dept id`) FROM h2.test.dept")
-    checkAggregateRemoved(df)
-    checkPushedInfo(df, "PushedAggregates: [COUNT(`dept id`)]")
-    checkAnswer(df, Seq(Row(2)))
+    checkAnswer(sql("SELECT `dept id`, `dept.id` FROM h2.test.dept"), Seq(Row(1, 1), Row(2, 1)))
+
+    val df1 = sql("SELECT COUNT(`dept id`) FROM h2.test.dept")
+    checkPushedInfo(df1, "PushedAggregates: [COUNT(`dept id`)]")
+    checkAnswer(df1, Seq(Row(2)))
+
+    val df2 = sql("SELECT `dept.id`, COUNT(`dept id`) FROM h2.test.dept GROUP BY `dept.id`")
+    checkPushedInfo(df2,
+      "PushedGroupByExpressions: [`dept.id`]", "PushedAggregates: [COUNT(`dept id`)]")
+    checkAnswer(df2, Seq(Row(1, 2)))
+
+    val df3 = sql("SELECT `dept id`, COUNT(`dept.id`) FROM h2.test.dept GROUP BY `dept id`")
+    checkPushedInfo(df3,
+      "PushedGroupByExpressions: [`dept id`]", "PushedAggregates: [COUNT(`dept.id`)]")
+    checkAnswer(df3, Seq(Row(1, 1), Row(2, 1)))
   }
 
   test("column name with non-ascii") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org