You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2022/06/30 10:03:19 UTC
[spark] branch master updated: [SPARK-39139][SQL] DS V2 supports push down DS V2 UDF
This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new b242f85547a [SPARK-39139][SQL] DS V2 supports push down DS V2 UDF
b242f85547a is described below
commit b242f85547a50376a583cb19a607587aafeaa1f8
Author: Jiaan Geng <be...@163.com>
AuthorDate: Thu Jun 30 18:02:53 2022 +0800
[SPARK-39139][SQL] DS V2 supports push down DS V2 UDF
### What changes were proposed in this pull request?
Currently, the Spark DS V2 push-down framework supports pushing down SQL to data sources.
However, the DS V2 push-down framework only supports pushing down built-in functions to data sources.
Each database has many very useful functions that are not supported by Spark.
If we can push these functions down to the data source, it will reduce disk I/O and network I/O and improve performance when querying databases.
### Why are the changes needed?
1. Spark can leverage the functions supported by databases
2. Improve the query performance.
### Does this PR introduce _any_ user-facing change?
'No'.
New feature.
### How was this patch tested?
New tests.
Closes #36593 from beliefer/SPARK-39139.
Authored-by: Jiaan Geng <be...@163.com>
Signed-off-by: Wenchen Fan <we...@databricks.com>
---
.../expressions/GeneralScalarExpression.java | 11 +-
.../expressions/UserDefinedScalarFunc.java | 70 ++++++++
.../aggregate/GeneralAggregateFunc.java | 20 +--
...gateFunc.java => UserDefinedAggregateFunc.java} | 48 ++----
.../sql/connector/util/V2ExpressionSQLBuilder.java | 37 +++++
.../internal/connector/ToStringSQLBuilder.scala | 38 +++++
.../sql/catalyst/util/V2ExpressionBuilder.scala | 10 +-
.../execution/datasources/DataSourceStrategy.scala | 10 +-
.../datasources/v2/V2ScanRelationPushDown.scala | 3 +-
.../sql/connector/DataSourceV2FunctionSuite.scala | 38 ++---
.../org/apache/spark/sql/jdbc/JDBCV2Suite.scala | 179 +++++++++++++++++++--
11 files changed, 380 insertions(+), 84 deletions(-)
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
index d4c27196eaf..ab9e33e86be 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
@@ -23,7 +23,7 @@ import java.util.Objects;
import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.filter.Predicate;
-import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder;
+import org.apache.spark.sql.internal.connector.ToStringSQLBuilder;
/**
* The general representation of SQL scalar expressions, which contains the upper-cased
@@ -381,12 +381,7 @@ public class GeneralScalarExpression implements Expression, Serializable {
@Override
public String toString() {
- V2ExpressionSQLBuilder builder = new V2ExpressionSQLBuilder();
- try {
- return builder.build(this);
- } catch (Throwable e) {
- return name + "(" +
- Arrays.stream(children).map(child -> child.toString()).reduce((a,b) -> a + "," + b) + ")";
- }
+ ToStringSQLBuilder builder = new ToStringSQLBuilder();
+ return builder.build(this);
}
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/UserDefinedScalarFunc.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/UserDefinedScalarFunc.java
new file mode 100644
index 00000000000..8e4155f81b8
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/UserDefinedScalarFunc.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.expressions;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.Objects;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.internal.connector.ToStringSQLBuilder;
+
+/**
+ * The general representation of user defined scalar function, which contains the upper-cased
+ * function name, canonical function name and all the children expressions.
+ *
+ * @since 3.4.0
+ */
+@Evolving
+public class UserDefinedScalarFunc implements Expression, Serializable {
+ private String name;
+ private String canonicalName;
+ private Expression[] children;
+
+ public UserDefinedScalarFunc(String name, String canonicalName, Expression[] children) {
+ this.name = name;
+ this.canonicalName = canonicalName;
+ this.children = children;
+ }
+
+ public String name() { return name; }
+ public String canonicalName() { return canonicalName; }
+
+ @Override
+ public Expression[] children() { return children; }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ UserDefinedScalarFunc that = (UserDefinedScalarFunc) o;
+ return Objects.equals(name, that.name) && Objects.equals(canonicalName, that.canonicalName) &&
+ Arrays.equals(children, that.children);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(name, canonicalName, children);
+ }
+
+ @Override
+ public String toString() {
+ ToStringSQLBuilder builder = new ToStringSQLBuilder();
+ return builder.build(this);
+ }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
index 70166445434..81838074fb1 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
@@ -17,11 +17,9 @@
package org.apache.spark.sql.connector.expressions.aggregate;
-import java.util.Arrays;
-import java.util.stream.Collectors;
-
import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.Expression;
+import org.apache.spark.sql.internal.connector.ToStringSQLBuilder;
/**
* The general implementation of {@link AggregateFunc}, which contains the upper-cased function
@@ -47,27 +45,21 @@ public final class GeneralAggregateFunc implements AggregateFunc {
private final boolean isDistinct;
private final Expression[] children;
- public String name() { return name; }
- public boolean isDistinct() { return isDistinct; }
-
public GeneralAggregateFunc(String name, boolean isDistinct, Expression[] children) {
this.name = name;
this.isDistinct = isDistinct;
this.children = children;
}
+ public String name() { return name; }
+ public boolean isDistinct() { return isDistinct; }
+
@Override
public Expression[] children() { return children; }
@Override
public String toString() {
- String inputsString = Arrays.stream(children)
- .map(Expression::describe)
- .collect(Collectors.joining(", "));
- if (isDistinct) {
- return name + "(DISTINCT " + inputsString + ")";
- } else {
- return name + "(" + inputsString + ")";
- }
+ ToStringSQLBuilder builder = new ToStringSQLBuilder();
+ return builder.build(this);
}
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/UserDefinedAggregateFunc.java
similarity index 52%
copy from sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
copy to sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/UserDefinedAggregateFunc.java
index 70166445434..9a89e7a89c9 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/UserDefinedAggregateFunc.java
@@ -17,57 +17,43 @@
package org.apache.spark.sql.connector.expressions.aggregate;
-import java.util.Arrays;
-import java.util.stream.Collectors;
-
import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.expressions.Expression;
+import org.apache.spark.sql.internal.connector.ToStringSQLBuilder;
/**
- * The general implementation of {@link AggregateFunc}, which contains the upper-cased function
- * name, the `isDistinct` flag and all the inputs. Note that Spark cannot push down partial
- * aggregate with this function to the source, but can only push down the entire aggregate.
- * <p>
- * The currently supported SQL aggregate functions:
- * <ol>
- * <li><pre>VAR_POP(input1)</pre> Since 3.3.0</li>
- * <li><pre>VAR_SAMP(input1)</pre> Since 3.3.0</li>
- * <li><pre>STDDEV_POP(input1)</pre> Since 3.3.0</li>
- * <li><pre>STDDEV_SAMP(input1)</pre> Since 3.3.0</li>
- * <li><pre>COVAR_POP(input1, input2)</pre> Since 3.3.0</li>
- * <li><pre>COVAR_SAMP(input1, input2)</pre> Since 3.3.0</li>
- * <li><pre>CORR(input1, input2)</pre> Since 3.3.0</li>
- * </ol>
+ * The general representation of user defined aggregate function, which implements
+ * {@link AggregateFunc}, contains the upper-cased function name, the canonical function name,
+ * the `isDistinct` flag and all the inputs. Note that Spark cannot push down aggregate with
+ * this function partially to the source, but can only push down the entire aggregate.
*
- * @since 3.3.0
+ * @since 3.4.0
*/
@Evolving
-public final class GeneralAggregateFunc implements AggregateFunc {
+public class UserDefinedAggregateFunc implements AggregateFunc {
private final String name;
+ private String canonicalName;
private final boolean isDistinct;
private final Expression[] children;
- public String name() { return name; }
- public boolean isDistinct() { return isDistinct; }
-
- public GeneralAggregateFunc(String name, boolean isDistinct, Expression[] children) {
+ public UserDefinedAggregateFunc(
+ String name, String canonicalName, boolean isDistinct, Expression[] children) {
this.name = name;
+ this.canonicalName = canonicalName;
this.isDistinct = isDistinct;
this.children = children;
}
+ public String name() { return name; }
+ public String canonicalName() { return canonicalName; }
+ public boolean isDistinct() { return isDistinct; }
+
@Override
public Expression[] children() { return children; }
@Override
public String toString() {
- String inputsString = Arrays.stream(children)
- .map(Expression::describe)
- .collect(Collectors.joining(", "));
- if (isDistinct) {
- return name + "(DISTINCT " + inputsString + ")";
- } else {
- return name + "(" + inputsString + ")";
- }
+ ToStringSQLBuilder builder = new ToStringSQLBuilder();
+ return builder.build(this);
}
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
index 8ae70a7529c..9b62fedcc80 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
@@ -26,6 +26,9 @@ import org.apache.spark.sql.connector.expressions.Expression;
import org.apache.spark.sql.connector.expressions.NamedReference;
import org.apache.spark.sql.connector.expressions.GeneralScalarExpression;
import org.apache.spark.sql.connector.expressions.Literal;
+import org.apache.spark.sql.connector.expressions.UserDefinedScalarFunc;
+import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc;
+import org.apache.spark.sql.connector.expressions.aggregate.UserDefinedAggregateFunc;
import org.apache.spark.sql.types.DataType;
/**
@@ -156,6 +159,18 @@ public class V2ExpressionSQLBuilder {
default:
return visitUnexpectedExpr(expr);
}
+ } else if (expr instanceof GeneralAggregateFunc) {
+ GeneralAggregateFunc f = (GeneralAggregateFunc) expr;
+ return visitGeneralAggregateFunction(f.name(), f.isDistinct(),
+ Arrays.stream(f.children()).map(c -> build(c)).toArray(String[]::new));
+ } else if (expr instanceof UserDefinedScalarFunc) {
+ UserDefinedScalarFunc f = (UserDefinedScalarFunc) expr;
+ return visitUserDefinedScalarFunction(f.name(), f.canonicalName(),
+ Arrays.stream(f.children()).map(c -> build(c)).toArray(String[]::new));
+ } else if (expr instanceof UserDefinedAggregateFunc) {
+ UserDefinedAggregateFunc f = (UserDefinedAggregateFunc) expr;
+ return visitUserDefinedAggregateFunction(f.name(), f.canonicalName(), f.isDistinct(),
+ Arrays.stream(f.children()).map(c -> build(c)).toArray(String[]::new));
} else {
return visitUnexpectedExpr(expr);
}
@@ -268,6 +283,28 @@ public class V2ExpressionSQLBuilder {
return funcName + "(" + Arrays.stream(inputs).collect(Collectors.joining(", ")) + ")";
}
+ protected String visitGeneralAggregateFunction(
+ String funcName, boolean isDistinct, String[] inputs) {
+ if (isDistinct) {
+ return funcName +
+ "(DISTINCT " + Arrays.stream(inputs).collect(Collectors.joining(", ")) + ")";
+ } else {
+ return funcName + "(" + Arrays.stream(inputs).collect(Collectors.joining(", ")) + ")";
+ }
+ }
+
+ protected String visitUserDefinedScalarFunction(
+ String funcName, String canonicalName, String[] inputs) {
+ throw new UnsupportedOperationException(
+ this.getClass().getSimpleName() + " does not support user defined function: " + funcName);
+ }
+
+ protected String visitUserDefinedAggregateFunction(
+ String funcName, String canonicalName, boolean isDistinct, String[] inputs) {
+ throw new UnsupportedOperationException(this.getClass().getSimpleName() +
+ " does not support user defined aggregate function: " + funcName);
+ }
+
protected String visitUnexpectedExpr(Expression expr) throws IllegalArgumentException {
throw new IllegalArgumentException("Unexpected V2 expression: " + expr);
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/ToStringSQLBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/ToStringSQLBuilder.scala
new file mode 100644
index 00000000000..889fdd4ebf2
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/ToStringSQLBuilder.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.internal.connector
+
+import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder
+
+/**
+ * The builder to generate `toString` information of V2 expressions.
+ */
+class ToStringSQLBuilder extends V2ExpressionSQLBuilder {
+ override protected def visitUserDefinedScalarFunction(
+ funcName: String, canonicalName: String, inputs: Array[String]) =
+ s"""$funcName(${inputs.mkString(", ")})"""
+
+ override protected def visitUserDefinedAggregateFunction(
+ funcName: String,
+ canonicalName: String,
+ isDistinct: Boolean,
+ inputs: Array[String]): String = {
+ val distinct = if (isDistinct) "DISTINCT " else ""
+ s"""$funcName($distinct${inputs.mkString(", ")})"""
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
index 1a9b9202cbe..163e071f08e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.util
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue}
+import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue, UserDefinedScalarFunc}
import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate}
import org.apache.spark.sql.types.BooleanType
@@ -345,6 +345,14 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) {
None
}
// TODO supports other expressions
+ case ApplyFunctionExpression(function, children) =>
+ val childrenExpressions = children.flatMap(generateExpression(_))
+ if (childrenExpressions.length == children.length) {
+ Some(new UserDefinedScalarFunc(
+ function.name(), function.canonicalName(), childrenExpressions.toArray[V2Expression]))
+ } else {
+ None
+ }
case _ => None
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index ba4c76285d5..8882261d967 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -42,7 +42,7 @@ import org.apache.spark.sql.catalyst.util.{ResolveDefaultColumns, V2ExpressionBu
import org.apache.spark.sql.connector.catalog.SupportsRead
import org.apache.spark.sql.connector.catalog.TableCapability._
import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference, NullOrdering, SortDirection, SortOrder => V2SortOrder, SortValue}
-import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum}
+import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum, UserDefinedAggregateFunc}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.command._
@@ -751,6 +751,14 @@ object DataSourceStrategy
PushableColumnWithoutNestedColumn(right), _) =>
Some(new GeneralAggregateFunc("CORR", agg.isDistinct,
Array(FieldReference.column(left), FieldReference.column(right))))
+ case aggregate.V2Aggregator(aggrFunc, children, _, _) =>
+ val translatedExprs = children.flatMap(PushableExpression.unapply(_))
+ if (translatedExprs.length == children.length) {
+ Some(new UserDefinedAggregateFunc(aggrFunc.name(),
+ aggrFunc.canonicalName(), agg.isDistinct, translatedExprs.toArray[V2Expression]))
+ } else {
+ None
+ }
case _ => None
}
} else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index b55aeefca0e..8951c37e127 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.planning.ScanOperation
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LimitAndOffset, LocalLimit, LogicalPlan, Offset, OffsetAndLimit, Project, Sample, Sort}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.expressions.{SortOrder => V2SortOrder}
-import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, Count, GeneralAggregateFunc, Sum}
+import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, Count, GeneralAggregateFunc, Sum, UserDefinedAggregateFunc}
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
@@ -299,6 +299,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper wit
case count: Count => !count.isDistinct
case avg: Avg => !avg.isDistinct
case _: GeneralAggregateFunc => false
+ case _: UserDefinedAggregateFunc => false
case _ => true
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
index cb9aea914e1..4a070ace377 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
@@ -37,6 +37,7 @@ import org.apache.spark.unsafe.types.UTF8String
object IntAverage extends AggregateFunction[(Int, Int), Int] {
override def name(): String = "iavg"
+ override def canonicalName(): String = "h2.iavg"
override def inputTypes(): Array[DataType] = Array(IntegerType)
override def resultType(): DataType = IntegerType
@@ -65,6 +66,7 @@ object IntAverage extends AggregateFunction[(Int, Int), Int] {
object LongAverage extends AggregateFunction[(Long, Long), Long] {
override def name(): String = "iavg"
+ override def canonicalName(): String = "h2.iavg"
override def inputTypes(): Array[DataType] = Array(LongType)
override def resultType(): DataType = LongType
@@ -113,6 +115,24 @@ object IntegralAverage extends UnboundFunction {
| iavg(bigint) -> bigint""".stripMargin
}
+case class StrLen(impl: BoundFunction) extends UnboundFunction {
+ override def name(): String = "strlen"
+
+ override def bind(inputType: StructType): BoundFunction = {
+ if (inputType.fields.length != 1) {
+ throw new UnsupportedOperationException("Expect exactly one argument");
+ }
+ inputType.fields(0).dataType match {
+ case StringType => impl
+ case _ =>
+ throw new UnsupportedOperationException("Expect StringType")
+ }
+ }
+
+ override def description(): String =
+ "strlen: returns the length of the input string strlen(string) -> int"
+}
+
class DataSourceV2FunctionSuite extends DatasourceV2SQLBase {
private val emptyProps: java.util.Map[String, String] = Collections.emptyMap[String, String]
@@ -532,24 +552,6 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase {
}
}
- private case class StrLen(impl: BoundFunction) extends UnboundFunction {
- override def description(): String =
- """strlen: returns the length of the input string
- | strlen(string) -> int""".stripMargin
- override def name(): String = "strlen"
-
- override def bind(inputType: StructType): BoundFunction = {
- if (inputType.fields.length != 1) {
- throw new UnsupportedOperationException("Expect exactly one argument");
- }
- inputType.fields(0).dataType match {
- case StringType => impl
- case _ =>
- throw new UnsupportedOperationException("Expect StringType")
- }
- }
- }
-
private case object StrLenDefault extends ScalarFunction[Int] {
override def inputTypes(): Array[DataType] = Array(StringType)
override def resultType(): DataType = IntegerType
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index c625f55ef8b..90ab976d9d5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -20,16 +20,23 @@ package org.apache.spark.sql.jdbc
import java.sql.{Connection, DriverManager}
import java.util.Properties
+import scala.util.control.NonFatal
+
import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.sql.{AnalysisException, DataFrame, ExplainSuiteHelper, QueryTest, Row}
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, LocalLimit, Offset, Sort}
-import org.apache.spark.sql.connector.IntegralAverage
+import org.apache.spark.sql.connector.{IntegralAverage, StrLen}
+import org.apache.spark.sql.connector.catalog.functions.{ScalarFunction, UnboundFunction}
+import org.apache.spark.sql.connector.expressions.Expression
+import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, UserDefinedAggregateFunc}
import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper}
import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
import org.apache.spark.sql.functions.{abs, acos, asin, atan, atan2, avg, ceil, coalesce, cos, cosh, cot, count, count_distinct, degrees, exp, floor, lit, log => logarithm, log10, not, pow, radians, round, signum, sin, sinh, sqrt, sum, tan, tanh, udf, when}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{DataType, IntegerType, StringType}
import org.apache.spark.util.Utils
class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHelper {
@@ -39,6 +46,75 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
val url = s"jdbc:h2:${tempDir.getCanonicalPath};user=testUser;password=testPass"
var conn: java.sql.Connection = null
+ val testH2Dialect = new JdbcDialect {
+ override def canHandle(url: String): Boolean = H2Dialect.canHandle(url)
+
+ class H2SQLBuilder extends JDBCSQLBuilder {
+ override def visitUserDefinedScalarFunction(
+ funcName: String, canonicalName: String, inputs: Array[String]): String = {
+ canonicalName match {
+ case "h2.char_length" =>
+ s"$funcName(${inputs.mkString(", ")})"
+ case _ => super.visitUserDefinedScalarFunction(funcName, canonicalName, inputs)
+ }
+ }
+
+ override def visitUserDefinedAggregateFunction(
+ funcName: String,
+ canonicalName: String,
+ isDistinct: Boolean,
+ inputs: Array[String]): String = {
+ canonicalName match {
+ case "h2.iavg" =>
+ if (isDistinct) {
+ s"$funcName(DISTINCT ${inputs.mkString(", ")})"
+ } else {
+ s"$funcName(${inputs.mkString(", ")})"
+ }
+ case _ =>
+ super.visitUserDefinedAggregateFunction(funcName, canonicalName, isDistinct, inputs)
+ }
+ }
+ }
+
+ override def compileExpression(expr: Expression): Option[String] = {
+ val h2SQLBuilder = new H2SQLBuilder()
+ try {
+ Some(h2SQLBuilder.build(expr))
+ } catch {
+ case NonFatal(e) =>
+ logWarning("Error occurs while compiling V2 expression", e)
+ None
+ }
+ }
+
+ override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
+ super.compileAggregate(aggFunction).orElse(
+ aggFunction match {
+ case f: UserDefinedAggregateFunc if f.name() == "iavg" =>
+ assert(f.children().length == 1)
+ val distinct = if (f.isDistinct) "DISTINCT " else ""
+ compileExpression(f.children().head).map(v => s"AVG($distinct$v)")
+ case _ => None
+ }
+ )
+ }
+
+ override def functions: Seq[(String, UnboundFunction)] = H2Dialect.functions
+ }
+
+ case object CharLength extends ScalarFunction[Int] {
+ override def inputTypes(): Array[DataType] = Array(StringType)
+ override def resultType(): DataType = IntegerType
+ override def name(): String = "CHAR_LENGTH"
+ override def canonicalName(): String = "h2.char_length"
+
+ override def produceResult(input: InternalRow): Int = {
+ val s = input.getString(0)
+ s.length
+ }
+ }
+
override def sparkConf: SparkConf = super.sparkConf
.set("spark.sql.catalog.h2", classOf[JDBCTableCatalog].getName)
.set("spark.sql.catalog.h2.url", url)
@@ -108,6 +184,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
"(1, 'bottle', 99999999999999999999.123)").executeUpdate()
}
H2Dialect.registerFunction("my_avg", IntegralAverage)
+ H2Dialect.registerFunction("my_strlen", StrLen(CharLength))
}
override def afterAll(): Unit = {
@@ -961,6 +1038,33 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
}
}
+ test("scan with filter push-down with UDF") {
+ JdbcDialects.unregisterDialect(H2Dialect)
+ try {
+ JdbcDialects.registerDialect(testH2Dialect)
+ val df1 = sql("SELECT * FROM h2.test.people where h2.my_strlen(name) > 2")
+ checkFiltersRemoved(df1)
+ checkPushedInfo(df1, "PushedFilters: [CHAR_LENGTH(NAME) > 2],")
+ checkAnswer(df1, Seq(Row("fred", 1), Row("mary", 2)))
+
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
+ val df2 = sql(
+ """
+ |SELECT *
+ |FROM h2.test.people
+ |WHERE h2.my_strlen(CASE WHEN NAME = 'fred' THEN NAME ELSE "abc" END) > 2
+ """.stripMargin)
+ checkFiltersRemoved(df2)
+ checkPushedInfo(df2,
+ "PushedFilters: [CHAR_LENGTH(CASE WHEN NAME = 'fred' THEN NAME ELSE 'abc' END) > 2],")
+ checkAnswer(df2, Seq(Row("fred", 1), Row("mary", 2)))
+ }
+ } finally {
+ JdbcDialects.unregisterDialect(testH2Dialect)
+ JdbcDialects.registerDialect(H2Dialect)
+ }
+ }
+
test("scan with column pruning") {
val df = spark.table("h2.test.people").select("id")
checkSchemaNames(df, Seq("ID"))
@@ -1884,16 +1988,71 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
}
test("register dialect specific functions") {
- val df = sql("SELECT h2.my_avg(id) FROM h2.test.people")
- checkAggregateRemoved(df, false)
- checkAnswer(df, Row(1) :: Nil)
- val e1 = intercept[AnalysisException] {
- checkAnswer(sql("SELECT h2.test.my_avg2(id) FROM h2.test.people"), Seq.empty)
+ JdbcDialects.unregisterDialect(H2Dialect)
+ try {
+ JdbcDialects.registerDialect(testH2Dialect)
+ val df = sql("SELECT h2.my_avg(id) FROM h2.test.people")
+ checkAggregateRemoved(df)
+ checkAnswer(df, Row(1) :: Nil)
+ val e1 = intercept[AnalysisException] {
+ checkAnswer(sql("SELECT h2.test.my_avg2(id) FROM h2.test.people"), Seq.empty)
+ }
+ assert(e1.getMessage.contains("Undefined function: h2.test.my_avg2"))
+ val e2 = intercept[AnalysisException] {
+ checkAnswer(sql("SELECT h2.my_avg2(id) FROM h2.test.people"), Seq.empty)
+ }
+ assert(e2.getMessage.contains("Undefined function: h2.my_avg2"))
+ } finally {
+ JdbcDialects.unregisterDialect(testH2Dialect)
+ JdbcDialects.registerDialect(H2Dialect)
}
- assert(e1.getMessage.contains("Undefined function: h2.test.my_avg2"))
- val e2 = intercept[AnalysisException] {
- checkAnswer(sql("SELECT h2.my_avg2(id) FROM h2.test.people"), Seq.empty)
+ }
+
+ test("scan with aggregate push-down: complete push-down UDAF") {
+ JdbcDialects.unregisterDialect(H2Dialect)
+ try {
+ JdbcDialects.registerDialect(testH2Dialect)
+ val df1 = sql("SELECT h2.my_avg(id) FROM h2.test.people")
+ checkAggregateRemoved(df1)
+ checkPushedInfo(df1,
+ "PushedAggregates: [iavg(ID)], PushedFilters: [], PushedGroupByExpressions: []")
+ checkAnswer(df1, Seq(Row(1)))
+
+ val df2 = sql("SELECT name, h2.my_avg(id) FROM h2.test.people group by name")
+ checkAggregateRemoved(df2)
+ checkPushedInfo(df2,
+ "PushedAggregates: [iavg(ID)], PushedFilters: [], PushedGroupByExpressions: [NAME]")
+ checkAnswer(df2, Seq(Row("fred", 1), Row("mary", 2)))
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") {
+ val df3 = sql(
+ """
+ |SELECT
+ | h2.my_avg(CASE WHEN NAME = 'fred' THEN id + 1 ELSE id END)
+ |FROM h2.test.people
+ """.stripMargin)
+ checkAggregateRemoved(df3)
+ checkPushedInfo(df3,
+ "PushedAggregates: [iavg(CASE WHEN NAME = 'fred' THEN ID + 1 ELSE ID END)]," +
+ " PushedFilters: [], PushedGroupByExpressions: []")
+ checkAnswer(df3, Seq(Row(2)))
+
+ val df4 = sql(
+ """
+ |SELECT
+ | name,
+ | h2.my_avg(CASE WHEN NAME = 'fred' THEN id + 1 ELSE id END)
+ |FROM h2.test.people
+ |GROUP BY name
+ """.stripMargin)
+ checkAggregateRemoved(df4)
+ checkPushedInfo(df4,
+ "PushedAggregates: [iavg(CASE WHEN NAME = 'fred' THEN ID + 1 ELSE ID END)]," +
+ " PushedFilters: [], PushedGroupByExpressions: [NAME]")
+ checkAnswer(df4, Seq(Row("fred", 2), Row("mary", 2)))
+ }
+ } finally {
+ JdbcDialects.unregisterDialect(testH2Dialect)
+ JdbcDialects.registerDialect(H2Dialect)
}
- assert(e2.getMessage.contains("Undefined function: h2.my_avg2"))
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org