Posted to commits@spark.apache.org by we...@apache.org on 2022/06/30 10:03:19 UTC

[spark] branch master updated: [SPARK-39139][SQL] DS V2 supports push down DS V2 UDF

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new b242f85547a [SPARK-39139][SQL] DS V2 supports push down DS V2 UDF
b242f85547a is described below

commit b242f85547a50376a583cb19a607587aafeaa1f8
Author: Jiaan Geng <be...@163.com>
AuthorDate: Thu Jun 30 18:02:53 2022 +0800

    [SPARK-39139][SQL] DS V2 supports push down DS V2 UDF
    
    ### What changes were proposed in this pull request?
    Currently, the Spark DS V2 push-down framework supports pushing down SQL to data sources,
    but it can only push down Spark's built-in functions.
    Each database has many useful functions that are not supported by Spark.
    If we can push these functions down to the data source as well, we reduce disk and network I/O and improve performance when querying databases.
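    
    For illustration only (not part of the patch itself), a minimal sketch of what this enables, assuming the H2 JDBC catalog and the `my_strlen` UDF registered in the tests below:
    
        // The UDF call compiles to database SQL and is pushed into the
        // JDBC scan instead of being evaluated row-by-row in Spark.
        val df = spark.sql(
          "SELECT * FROM h2.test.people WHERE h2.my_strlen(name) > 2")
        df.explain()  // PushedFilters: [CHAR_LENGTH(NAME) > 2]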
    
    ### Why are the changes needed?
    
    1. Spark can leverage functions supported by the underlying databases.
    2. Improve query performance.
    
    ### Does this PR introduce _any_ user-facing change?
    'No'.
    New feature.
    
    ### How was this patch tested?
    New tests.
    
    Closes #36593 from beliefer/SPARK-39139.
    
    Authored-by: Jiaan Geng <be...@163.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../expressions/GeneralScalarExpression.java       |  11 +-
 .../expressions/UserDefinedScalarFunc.java         |  70 ++++++++
 .../aggregate/GeneralAggregateFunc.java            |  20 +--
 ...gateFunc.java => UserDefinedAggregateFunc.java} |  48 ++----
 .../sql/connector/util/V2ExpressionSQLBuilder.java |  37 +++++
 .../internal/connector/ToStringSQLBuilder.scala    |  38 +++++
 .../sql/catalyst/util/V2ExpressionBuilder.scala    |  10 +-
 .../execution/datasources/DataSourceStrategy.scala |  10 +-
 .../datasources/v2/V2ScanRelationPushDown.scala    |   3 +-
 .../sql/connector/DataSourceV2FunctionSuite.scala  |  38 ++---
 .../org/apache/spark/sql/jdbc/JDBCV2Suite.scala    | 179 +++++++++++++++++++--
 11 files changed, 380 insertions(+), 84 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
index d4c27196eaf..ab9e33e86be 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
@@ -23,7 +23,7 @@ import java.util.Objects;
 
 import org.apache.spark.annotation.Evolving;
 import org.apache.spark.sql.connector.expressions.filter.Predicate;
-import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder;
+import org.apache.spark.sql.internal.connector.ToStringSQLBuilder;
 
 /**
  * The general representation of SQL scalar expressions, which contains the upper-cased
@@ -381,12 +381,7 @@ public class GeneralScalarExpression implements Expression, Serializable {
 
   @Override
   public String toString() {
-    V2ExpressionSQLBuilder builder = new V2ExpressionSQLBuilder();
-    try {
-      return builder.build(this);
-    } catch (Throwable e) {
-      return name + "(" +
-        Arrays.stream(children).map(child -> child.toString()).reduce((a,b) -> a + "," + b) + ")";
-    }
+    ToStringSQLBuilder builder = new ToStringSQLBuilder();
+    return builder.build(this);
   }
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/UserDefinedScalarFunc.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/UserDefinedScalarFunc.java
new file mode 100644
index 00000000000..8e4155f81b8
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/UserDefinedScalarFunc.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.expressions;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.Objects;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.internal.connector.ToStringSQLBuilder;
+
+/**
+ * The general representation of user defined scalar function, which contains the upper-cased
+ * function name, canonical function name and all the children expressions.
+ *
+ * @since 3.4.0
+ */
+@Evolving
+public class UserDefinedScalarFunc implements Expression, Serializable {
+  private String name;
+  private String canonicalName;
+  private Expression[] children;
+
+  public UserDefinedScalarFunc(String name, String canonicalName, Expression[] children) {
+    this.name = name;
+    this.canonicalName = canonicalName;
+    this.children = children;
+  }
+
+  public String name() { return name; }
+  public String canonicalName() { return canonicalName; }
+
+  @Override
+  public Expression[] children() { return children; }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) return true;
+    if (o == null || getClass() != o.getClass()) return false;
+    UserDefinedScalarFunc that = (UserDefinedScalarFunc) o;
+    return Objects.equals(name, that.name) && Objects.equals(canonicalName, that.canonicalName) &&
+      Arrays.equals(children, that.children);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(name, canonicalName, children);
+  }
+
+  @Override
+  public String toString() {
+    ToStringSQLBuilder builder = new ToStringSQLBuilder();
+    return builder.build(this);
+  }
+}
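
A minimal sketch of constructing the new expression from Scala (illustrative names, not part of this diff):

    import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, UserDefinedScalarFunc}

    // A call to a connector-defined STRLEN function over one input column;
    // "h2.strlen" is a hypothetical canonical name a dialect can match on
    // when compiling the expression to database-specific SQL.
    val udf = new UserDefinedScalarFunc(
      "STRLEN", "h2.strlen", Array[Expression](FieldReference.column("name")))
    println(udf)  // rendered via ToStringSQLBuilder, e.g. STRLEN(name)
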
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
index 70166445434..81838074fb1 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
@@ -17,11 +17,9 @@
 
 package org.apache.spark.sql.connector.expressions.aggregate;
 
-import java.util.Arrays;
-import java.util.stream.Collectors;
-
 import org.apache.spark.annotation.Evolving;
 import org.apache.spark.sql.connector.expressions.Expression;
+import org.apache.spark.sql.internal.connector.ToStringSQLBuilder;
 
 /**
  * The general implementation of {@link AggregateFunc}, which contains the upper-cased function
@@ -47,27 +45,21 @@ public final class GeneralAggregateFunc implements AggregateFunc {
   private final boolean isDistinct;
   private final Expression[] children;
 
-  public String name() { return name; }
-  public boolean isDistinct() { return isDistinct; }
-
   public GeneralAggregateFunc(String name, boolean isDistinct, Expression[] children) {
     this.name = name;
     this.isDistinct = isDistinct;
     this.children = children;
   }
 
+  public String name() { return name; }
+  public boolean isDistinct() { return isDistinct; }
+
   @Override
   public Expression[] children() { return children; }
 
   @Override
   public String toString() {
-    String inputsString = Arrays.stream(children)
-      .map(Expression::describe)
-      .collect(Collectors.joining(", "));
-    if (isDistinct) {
-      return name + "(DISTINCT " + inputsString + ")";
-    } else {
-      return name + "(" + inputsString + ")";
-    }
+    ToStringSQLBuilder builder = new ToStringSQLBuilder();
+    return builder.build(this);
   }
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/UserDefinedAggregateFunc.java
similarity index 52%
copy from sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
copy to sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/UserDefinedAggregateFunc.java
index 70166445434..9a89e7a89c9 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/UserDefinedAggregateFunc.java
@@ -17,57 +17,43 @@
 
 package org.apache.spark.sql.connector.expressions.aggregate;
 
-import java.util.Arrays;
-import java.util.stream.Collectors;
-
 import org.apache.spark.annotation.Evolving;
 import org.apache.spark.sql.connector.expressions.Expression;
+import org.apache.spark.sql.internal.connector.ToStringSQLBuilder;
 
 /**
- * The general implementation of {@link AggregateFunc}, which contains the upper-cased function
- * name, the `isDistinct` flag and all the inputs. Note that Spark cannot push down partial
- * aggregate with this function to the source, but can only push down the entire aggregate.
- * <p>
- * The currently supported SQL aggregate functions:
- * <ol>
- *  <li><pre>VAR_POP(input1)</pre> Since 3.3.0</li>
- *  <li><pre>VAR_SAMP(input1)</pre> Since 3.3.0</li>
- *  <li><pre>STDDEV_POP(input1)</pre> Since 3.3.0</li>
- *  <li><pre>STDDEV_SAMP(input1)</pre> Since 3.3.0</li>
- *  <li><pre>COVAR_POP(input1, input2)</pre> Since 3.3.0</li>
- *  <li><pre>COVAR_SAMP(input1, input2)</pre> Since 3.3.0</li>
- *  <li><pre>CORR(input1, input2)</pre> Since 3.3.0</li>
- * </ol>
+ * The general representation of user defined aggregate function, which implements
+ * {@link AggregateFunc}, contains the upper-cased function name, the canonical function name,
+ * the `isDistinct` flag and all the inputs. Note that Spark cannot push down aggregate with
+ * this function partially to the source, but can only push down the entire aggregate.
  *
- * @since 3.3.0
+ * @since 3.4.0
  */
 @Evolving
-public final class GeneralAggregateFunc implements AggregateFunc {
+public class UserDefinedAggregateFunc implements AggregateFunc {
   private final String name;
+  private String canonicalName;
   private final boolean isDistinct;
   private final Expression[] children;
 
-  public String name() { return name; }
-  public boolean isDistinct() { return isDistinct; }
-
-  public GeneralAggregateFunc(String name, boolean isDistinct, Expression[] children) {
+  public UserDefinedAggregateFunc(
+      String name, String canonicalName, boolean isDistinct, Expression[] children) {
     this.name = name;
+    this.canonicalName = canonicalName;
     this.isDistinct = isDistinct;
     this.children = children;
   }
 
+  public String name() { return name; }
+  public String canonicalName() { return canonicalName; }
+  public boolean isDistinct() { return isDistinct; }
+
   @Override
   public Expression[] children() { return children; }
 
   @Override
   public String toString() {
-    String inputsString = Arrays.stream(children)
-      .map(Expression::describe)
-      .collect(Collectors.joining(", "));
-    if (isDistinct) {
-      return name + "(DISTINCT " + inputsString + ")";
-    } else {
-      return name + "(" + inputsString + ")";
-    }
+    ToStringSQLBuilder builder = new ToStringSQLBuilder();
+    return builder.build(this);
   }
 }
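
A matching sketch for the aggregate variant (again illustrative, mirroring the `iavg` function used in the tests):

    import org.apache.spark.sql.connector.expressions.{Expression, FieldReference}
    import org.apache.spark.sql.connector.expressions.aggregate.UserDefinedAggregateFunc

    // A DISTINCT call to a connector-defined aggregate. As the class doc
    // notes, Spark can only push the entire aggregate to the source; there
    // is no partial-aggregate path for an opaque UDAF.
    val agg = new UserDefinedAggregateFunc(
      "IAVG", "h2.iavg", true, Array[Expression](FieldReference.column("id")))
    println(agg)  // e.g. IAVG(DISTINCT id)
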
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
index 8ae70a7529c..9b62fedcc80 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
@@ -26,6 +26,9 @@ import org.apache.spark.sql.connector.expressions.Expression;
 import org.apache.spark.sql.connector.expressions.NamedReference;
 import org.apache.spark.sql.connector.expressions.GeneralScalarExpression;
 import org.apache.spark.sql.connector.expressions.Literal;
+import org.apache.spark.sql.connector.expressions.UserDefinedScalarFunc;
+import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc;
+import org.apache.spark.sql.connector.expressions.aggregate.UserDefinedAggregateFunc;
 import org.apache.spark.sql.types.DataType;
 
 /**
@@ -156,6 +159,18 @@ public class V2ExpressionSQLBuilder {
         default:
           return visitUnexpectedExpr(expr);
       }
+    } else if (expr instanceof GeneralAggregateFunc) {
+      GeneralAggregateFunc f = (GeneralAggregateFunc) expr;
+      return visitGeneralAggregateFunction(f.name(), f.isDistinct(),
+        Arrays.stream(f.children()).map(c -> build(c)).toArray(String[]::new));
+    } else if (expr instanceof UserDefinedScalarFunc) {
+      UserDefinedScalarFunc f = (UserDefinedScalarFunc) expr;
+      return visitUserDefinedScalarFunction(f.name(), f.canonicalName(),
+        Arrays.stream(f.children()).map(c -> build(c)).toArray(String[]::new));
+    } else if (expr instanceof UserDefinedAggregateFunc) {
+      UserDefinedAggregateFunc f = (UserDefinedAggregateFunc) expr;
+      return visitUserDefinedAggregateFunction(f.name(), f.canonicalName(), f.isDistinct(),
+        Arrays.stream(f.children()).map(c -> build(c)).toArray(String[]::new));
     } else {
       return visitUnexpectedExpr(expr);
     }
@@ -268,6 +283,28 @@ public class V2ExpressionSQLBuilder {
     return funcName + "(" + Arrays.stream(inputs).collect(Collectors.joining(", ")) + ")";
   }
 
+  protected String visitGeneralAggregateFunction(
+      String funcName, boolean isDistinct, String[] inputs) {
+    if (isDistinct) {
+      return funcName +
+        "(DISTINCT " + Arrays.stream(inputs).collect(Collectors.joining(", ")) + ")";
+    } else {
+      return funcName + "(" + Arrays.stream(inputs).collect(Collectors.joining(", ")) + ")";
+    }
+  }
+
+  protected String visitUserDefinedScalarFunction(
+      String funcName, String canonicalName, String[] inputs) {
+    throw new UnsupportedOperationException(
+      this.getClass().getSimpleName() + " does not support user defined function: " + funcName);
+  }
+
+  protected String visitUserDefinedAggregateFunction(
+      String funcName, String canonicalName, boolean isDistinct, String[] inputs) {
+    throw new UnsupportedOperationException(this.getClass().getSimpleName() +
+      " does not support user defined aggregate function: " + funcName);
+  }
+
   protected String visitUnexpectedExpr(Expression expr) throws IllegalArgumentException {
     throw new IllegalArgumentException("Unexpected V2 expression: " + expr);
   }
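
The new visit methods default to throwing, so each dialect opts in per function. A sketch of a builder that recognizes a single canonical name (hypothetical dialect; the H2 test builder later in this diff follows the same pattern):

    import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder

    class MyDialectSQLBuilder extends V2ExpressionSQLBuilder {
      // Compile only the UDFs this source is known to support; anything else
      // falls through to the default, whose exception a dialect's
      // compileExpression typically catches to return None and skip push-down
      // (see the H2 test dialect below).
      override def visitUserDefinedScalarFunction(
          funcName: String, canonicalName: String, inputs: Array[String]): String =
        canonicalName match {
          case "h2.char_length" => s"$funcName(${inputs.mkString(", ")})"
          case _ => super.visitUserDefinedScalarFunction(funcName, canonicalName, inputs)
        }
    }
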
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/ToStringSQLBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/ToStringSQLBuilder.scala
new file mode 100644
index 00000000000..889fdd4ebf2
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/ToStringSQLBuilder.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.internal.connector
+
+import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder
+
+/**
+ * The builder to generate `toString` information of V2 expressions.
+ */
+class ToStringSQLBuilder extends V2ExpressionSQLBuilder {
+  override protected def visitUserDefinedScalarFunction(
+      funcName: String, canonicalName: String, inputs: Array[String]) =
+    s"""$funcName(${inputs.mkString(", ")})"""
+
+  override protected def visitUserDefinedAggregateFunction(
+      funcName: String,
+      canonicalName: String,
+      isDistinct: Boolean,
+      inputs: Array[String]): String = {
+    val distinct = if (isDistinct) "DISTINCT " else ""
+    s"""$funcName($distinct${inputs.mkString(", ")})"""
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
index 1a9b9202cbe..163e071f08e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.util
 
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue}
+import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue, UserDefinedScalarFunc}
 import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate}
 import org.apache.spark.sql.types.BooleanType
 
@@ -345,6 +345,14 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
         None
       }
     // TODO supports other expressions
+    case ApplyFunctionExpression(function, children) =>
+      val childrenExpressions = children.flatMap(generateExpression(_))
+      if (childrenExpressions.length == children.length) {
+        Some(new UserDefinedScalarFunc(
+          function.name(), function.canonicalName(), childrenExpressions.toArray[V2Expression]))
+      } else {
+        None
+      }
     case _ => None
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index ba4c76285d5..8882261d967 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -42,7 +42,7 @@ import org.apache.spark.sql.catalyst.util.{ResolveDefaultColumns, V2ExpressionBu
 import org.apache.spark.sql.connector.catalog.SupportsRead
 import org.apache.spark.sql.connector.catalog.TableCapability._
 import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference, NullOrdering, SortDirection, SortOrder => V2SortOrder, SortValue}
-import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum}
+import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum, UserDefinedAggregateFunc}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan}
 import org.apache.spark.sql.execution.command._
@@ -751,6 +751,14 @@ object DataSourceStrategy
         PushableColumnWithoutNestedColumn(right), _) =>
           Some(new GeneralAggregateFunc("CORR", agg.isDistinct,
             Array(FieldReference.column(left), FieldReference.column(right))))
+        case aggregate.V2Aggregator(aggrFunc, children, _, _) =>
+          val translatedExprs = children.flatMap(PushableExpression.unapply(_))
+          if (translatedExprs.length == children.length) {
+            Some(new UserDefinedAggregateFunc(aggrFunc.name(),
+              aggrFunc.canonicalName(), agg.isDistinct, translatedExprs.toArray[V2Expression]))
+          } else {
+            None
+          }
         case _ => None
       }
     } else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index b55aeefca0e..8951c37e127 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.planning.ScanOperation
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LimitAndOffset, LocalLimit, LogicalPlan, Offset, OffsetAndLimit, Project, Sample, Sort}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.connector.expressions.{SortOrder => V2SortOrder}
-import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, Count, GeneralAggregateFunc, Sum}
+import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, Count, GeneralAggregateFunc, Sum, UserDefinedAggregateFunc}
 import org.apache.spark.sql.connector.expressions.filter.Predicate
 import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
@@ -299,6 +299,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
       case count: Count => !count.isDistinct
       case avg: Avg => !avg.isDistinct
       case _: GeneralAggregateFunc => false
+      case _: UserDefinedAggregateFunc => false
       case _ => true
     }
   }
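
Because the new case returns false, an opaque UDAF is never split into partial and final aggregates. A sketch of the observable effect, assuming the H2 catalog setup from the tests below:

    // With complete push-down the JDBC source computes the aggregate and
    // Spark's own Aggregate node disappears from the plan (the tests assert
    // this via checkAggregateRemoved); otherwise the UDAF runs in Spark.
    val df = spark.sql("SELECT h2.my_avg(id) FROM h2.test.people")
    df.explain()  // PushedAggregates: [iavg(ID)]
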
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
index cb9aea914e1..4a070ace377 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
@@ -37,6 +37,7 @@ import org.apache.spark.unsafe.types.UTF8String
 
 object IntAverage extends AggregateFunction[(Int, Int), Int] {
   override def name(): String = "iavg"
+  override def canonicalName(): String = "h2.iavg"
   override def inputTypes(): Array[DataType] = Array(IntegerType)
   override def resultType(): DataType = IntegerType
 
@@ -65,6 +66,7 @@ object IntAverage extends AggregateFunction[(Int, Int), Int] {
 
 object LongAverage extends AggregateFunction[(Long, Long), Long] {
   override def name(): String = "iavg"
+  override def canonicalName(): String = "h2.iavg"
   override def inputTypes(): Array[DataType] = Array(LongType)
   override def resultType(): DataType = LongType
 
@@ -113,6 +115,24 @@ object IntegralAverage extends UnboundFunction {
       |  iavg(bigint) -> bigint""".stripMargin
 }
 
+case class StrLen(impl: BoundFunction) extends UnboundFunction {
+  override def name(): String = "strlen"
+
+  override def bind(inputType: StructType): BoundFunction = {
+    if (inputType.fields.length != 1) {
+      throw new UnsupportedOperationException("Expect exactly one argument");
+    }
+    inputType.fields(0).dataType match {
+      case StringType => impl
+      case _ =>
+        throw new UnsupportedOperationException("Expect StringType")
+    }
+  }
+
+  override def description(): String =
+    "strlen: returns the length of the input string  strlen(string) -> int"
+}
+
 class DataSourceV2FunctionSuite extends DatasourceV2SQLBase {
   private val emptyProps: java.util.Map[String, String] = Collections.emptyMap[String, String]
 
@@ -532,24 +552,6 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase {
     }
   }
 
-  private case class StrLen(impl: BoundFunction) extends UnboundFunction {
-    override def description(): String =
-      """strlen: returns the length of the input string
-        |  strlen(string) -> int""".stripMargin
-    override def name(): String = "strlen"
-
-    override def bind(inputType: StructType): BoundFunction = {
-      if (inputType.fields.length != 1) {
-        throw new UnsupportedOperationException("Expect exactly one argument");
-      }
-      inputType.fields(0).dataType match {
-        case StringType => impl
-        case _ =>
-          throw new UnsupportedOperationException("Expect StringType")
-      }
-    }
-  }
-
   private case object StrLenDefault extends ScalarFunction[Int] {
     override def inputTypes(): Array[DataType] = Array(StringType)
     override def resultType(): DataType = IntegerType
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index c625f55ef8b..90ab976d9d5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -20,16 +20,23 @@ package org.apache.spark.sql.jdbc
 import java.sql.{Connection, DriverManager}
 import java.util.Properties
 
+import scala.util.control.NonFatal
+
 import org.apache.spark.{SparkConf, SparkException}
 import org.apache.spark.sql.{AnalysisException, DataFrame, ExplainSuiteHelper, QueryTest, Row}
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, LocalLimit, Offset, Sort}
-import org.apache.spark.sql.connector.IntegralAverage
+import org.apache.spark.sql.connector.{IntegralAverage, StrLen}
+import org.apache.spark.sql.connector.catalog.functions.{ScalarFunction, UnboundFunction}
+import org.apache.spark.sql.connector.expressions.Expression
+import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, UserDefinedAggregateFunc}
 import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper}
 import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
 import org.apache.spark.sql.functions.{abs, acos, asin, atan, atan2, avg, ceil, coalesce, cos, cosh, cot, count, count_distinct, degrees, exp, floor, lit, log => logarithm, log10, not, pow, radians, round, signum, sin, sinh, sqrt, sum, tan, tanh, udf, when}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{DataType, IntegerType, StringType}
 import org.apache.spark.util.Utils
 
 class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHelper {
@@ -39,6 +46,75 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
   val url = s"jdbc:h2:${tempDir.getCanonicalPath};user=testUser;password=testPass"
   var conn: java.sql.Connection = null
 
+  val testH2Dialect = new JdbcDialect {
+    override def canHandle(url: String): Boolean = H2Dialect.canHandle(url)
+
+    class H2SQLBuilder extends JDBCSQLBuilder {
+      override def visitUserDefinedScalarFunction(
+          funcName: String, canonicalName: String, inputs: Array[String]): String = {
+        canonicalName match {
+          case "h2.char_length" =>
+            s"$funcName(${inputs.mkString(", ")})"
+          case _ => super.visitUserDefinedScalarFunction(funcName, canonicalName, inputs)
+        }
+      }
+
+      override def visitUserDefinedAggregateFunction(
+          funcName: String,
+          canonicalName: String,
+          isDistinct: Boolean,
+          inputs: Array[String]): String = {
+        canonicalName match {
+          case "h2.iavg" =>
+            if (isDistinct) {
+              s"$funcName(DISTINCT ${inputs.mkString(", ")})"
+            } else {
+              s"$funcName(${inputs.mkString(", ")})"
+            }
+          case _ =>
+            super.visitUserDefinedAggregateFunction(funcName, canonicalName, isDistinct, inputs)
+        }
+      }
+    }
+
+    override def compileExpression(expr: Expression): Option[String] = {
+      val h2SQLBuilder = new H2SQLBuilder()
+      try {
+        Some(h2SQLBuilder.build(expr))
+      } catch {
+        case NonFatal(e) =>
+          logWarning("Error occurs while compiling V2 expression", e)
+          None
+      }
+    }
+
+    override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
+      super.compileAggregate(aggFunction).orElse(
+        aggFunction match {
+          case f: UserDefinedAggregateFunc if f.name() == "iavg" =>
+            assert(f.children().length == 1)
+            val distinct = if (f.isDistinct) "DISTINCT " else ""
+            compileExpression(f.children().head).map(v => s"AVG($distinct$v)")
+          case _ => None
+        }
+      )
+    }
+
+    override def functions: Seq[(String, UnboundFunction)] = H2Dialect.functions
+  }
+
+  case object CharLength extends ScalarFunction[Int] {
+    override def inputTypes(): Array[DataType] = Array(StringType)
+    override def resultType(): DataType = IntegerType
+    override def name(): String = "CHAR_LENGTH"
+    override def canonicalName(): String = "h2.char_length"
+
+    override def produceResult(input: InternalRow): Int = {
+      val s = input.getString(0)
+      s.length
+    }
+  }
+
   override def sparkConf: SparkConf = super.sparkConf
     .set("spark.sql.catalog.h2", classOf[JDBCTableCatalog].getName)
     .set("spark.sql.catalog.h2.url", url)
@@ -108,6 +184,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
         "(1, 'bottle', 99999999999999999999.123)").executeUpdate()
     }
     H2Dialect.registerFunction("my_avg", IntegralAverage)
+    H2Dialect.registerFunction("my_strlen", StrLen(CharLength))
   }
 
   override def afterAll(): Unit = {
@@ -961,6 +1038,33 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
     }
   }
 
+  test("scan with filter push-down with UDF") {
+    JdbcDialects.unregisterDialect(H2Dialect)
+    try {
+      JdbcDialects.registerDialect(testH2Dialect)
+      val df1 = sql("SELECT * FROM h2.test.people where h2.my_strlen(name) > 2")
+      checkFiltersRemoved(df1)
+      checkPushedInfo(df1, "PushedFilters: [CHAR_LENGTH(NAME) > 2],")
+      checkAnswer(df1, Seq(Row("fred", 1), Row("mary", 2)))
+
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
+        val df2 = sql(
+          """
+            |SELECT *
+            |FROM h2.test.people
+            |WHERE h2.my_strlen(CASE WHEN NAME = 'fred' THEN NAME ELSE "abc" END) > 2
+          """.stripMargin)
+        checkFiltersRemoved(df2)
+        checkPushedInfo(df2,
+          "PushedFilters: [CHAR_LENGTH(CASE WHEN NAME = 'fred' THEN NAME ELSE 'abc' END) > 2],")
+        checkAnswer(df2, Seq(Row("fred", 1), Row("mary", 2)))
+      }
+    } finally {
+      JdbcDialects.unregisterDialect(testH2Dialect)
+      JdbcDialects.registerDialect(H2Dialect)
+    }
+  }
+
   test("scan with column pruning") {
     val df = spark.table("h2.test.people").select("id")
     checkSchemaNames(df, Seq("ID"))
@@ -1884,16 +1988,71 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
   }
 
   test("register dialect specific functions") {
-    val df = sql("SELECT h2.my_avg(id) FROM h2.test.people")
-    checkAggregateRemoved(df, false)
-    checkAnswer(df, Row(1) :: Nil)
-    val e1 = intercept[AnalysisException] {
-      checkAnswer(sql("SELECT h2.test.my_avg2(id) FROM h2.test.people"), Seq.empty)
+    JdbcDialects.unregisterDialect(H2Dialect)
+    try {
+      JdbcDialects.registerDialect(testH2Dialect)
+      val df = sql("SELECT h2.my_avg(id) FROM h2.test.people")
+      checkAggregateRemoved(df)
+      checkAnswer(df, Row(1) :: Nil)
+      val e1 = intercept[AnalysisException] {
+        checkAnswer(sql("SELECT h2.test.my_avg2(id) FROM h2.test.people"), Seq.empty)
+      }
+      assert(e1.getMessage.contains("Undefined function: h2.test.my_avg2"))
+      val e2 = intercept[AnalysisException] {
+        checkAnswer(sql("SELECT h2.my_avg2(id) FROM h2.test.people"), Seq.empty)
+      }
+      assert(e2.getMessage.contains("Undefined function: h2.my_avg2"))
+    } finally {
+      JdbcDialects.unregisterDialect(testH2Dialect)
+      JdbcDialects.registerDialect(H2Dialect)
     }
-    assert(e1.getMessage.contains("Undefined function: h2.test.my_avg2"))
-    val e2 = intercept[AnalysisException] {
-      checkAnswer(sql("SELECT h2.my_avg2(id) FROM h2.test.people"), Seq.empty)
+  }
+
+  test("scan with aggregate push-down: complete push-down UDAF") {
+    JdbcDialects.unregisterDialect(H2Dialect)
+    try {
+      JdbcDialects.registerDialect(testH2Dialect)
+      val df1 = sql("SELECT h2.my_avg(id) FROM h2.test.people")
+      checkAggregateRemoved(df1)
+      checkPushedInfo(df1,
+        "PushedAggregates: [iavg(ID)], PushedFilters: [], PushedGroupByExpressions: []")
+      checkAnswer(df1, Seq(Row(1)))
+
+      val df2 = sql("SELECT name, h2.my_avg(id) FROM h2.test.people group by name")
+      checkAggregateRemoved(df2)
+      checkPushedInfo(df2,
+        "PushedAggregates: [iavg(ID)], PushedFilters: [], PushedGroupByExpressions: [NAME]")
+      checkAnswer(df2, Seq(Row("fred", 1), Row("mary", 2)))
+      withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
+        val df3 = sql(
+          """
+            |SELECT
+            |  h2.my_avg(CASE WHEN NAME = 'fred' THEN id + 1 ELSE id END)
+            |FROM h2.test.people
+          """.stripMargin)
+        checkAggregateRemoved(df3)
+        checkPushedInfo(df3,
+          "PushedAggregates: [iavg(CASE WHEN NAME = 'fred' THEN ID + 1 ELSE ID END)]," +
+            " PushedFilters: [], PushedGroupByExpressions: []")
+        checkAnswer(df3, Seq(Row(2)))
+
+        val df4 = sql(
+          """
+            |SELECT
+            |  name,
+            |  h2.my_avg(CASE WHEN NAME = 'fred' THEN id + 1 ELSE id END)
+            |FROM h2.test.people
+            |GROUP BY name
+          """.stripMargin)
+        checkAggregateRemoved(df4)
+        checkPushedInfo(df4,
+          "PushedAggregates: [iavg(CASE WHEN NAME = 'fred' THEN ID + 1 ELSE ID END)]," +
+            " PushedFilters: [], PushedGroupByExpressions: [NAME]")
+        checkAnswer(df4, Seq(Row("fred", 2), Row("mary", 2)))
+      }
+    } finally {
+      JdbcDialects.unregisterDialect(testH2Dialect)
+      JdbcDialects.registerDialect(H2Dialect)
     }
-    assert(e2.getMessage.contains("Undefined function: h2.my_avg2"))
   }
 }

