You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by su...@apache.org on 2022/01/19 22:10:55 UTC
[spark] branch branch-3.2 updated: [SPARK-37957][SQL] Correctly pass deterministic flag for V2 scalar functions
This is an automated email from the ASF dual-hosted git repository.
sunchao pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.2 by this push:
new 3860ac5 [SPARK-37957][SQL] Correctly pass deterministic flag for V2 scalar functions
3860ac5 is described below
commit 3860ac5d2313d438c25e5c46ed4e2b8b2b5227e3
Author: Chao Sun <su...@apple.com>
AuthorDate: Wed Jan 19 14:07:30 2022 -0800
[SPARK-37957][SQL] Correctly pass deterministic flag for V2 scalar functions
### What changes were proposed in this pull request?
Pass `isDeterministic` flag to `ApplyFunctionExpression`, `Invoke` and `StaticInvoke` when processing V2 scalar functions.
### Why are the changes needed?
A V2 scalar function can be declared as non-deterministic. However, Spark currently does not pass that flag along when converting the V2 function to a Catalyst expression, which can lead to incorrect results when certain optimizations (such as constant folding) are applied to the expression.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Added a unit test.
Closes #35243 from sunchao/SPARK-37957.
Authored-by: Chao Sun <su...@apple.com>
Signed-off-by: Chao Sun <su...@apple.com>
---
.../spark/sql/catalyst/analysis/Analyzer.scala | 6 +-
.../expressions/ApplyFunctionExpression.scala | 2 +
.../sql/catalyst/expressions/objects/objects.scala | 12 ++-
.../connector/catalog/functions/JavaLongAdd.java | 2 +-
.../connector/catalog/functions/JavaRandomAdd.java | 110 +++++++++++++++++++++
.../connector/catalog/functions/JavaStrLen.java | 2 +-
.../sql/connector/DataSourceV2FunctionSuite.scala | 33 ++++++-
7 files changed, 158 insertions(+), 9 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 89c7b5f..42bfa24 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2299,12 +2299,14 @@ class Analyzer(override val catalogManager: CatalogManager)
case Some(m) if Modifier.isStatic(m.getModifiers) =>
StaticInvoke(scalarFunc.getClass, scalarFunc.resultType(),
MAGIC_METHOD_NAME, arguments, inputTypes = declaredInputTypes,
- propagateNull = false, returnNullable = scalarFunc.isResultNullable)
+ propagateNull = false, returnNullable = scalarFunc.isResultNullable,
+ isDeterministic = scalarFunc.isDeterministic)
case Some(_) =>
val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass))
Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(),
arguments, methodInputTypes = declaredInputTypes, propagateNull = false,
- returnNullable = scalarFunc.isResultNullable)
+ returnNullable = scalarFunc.isResultNullable,
+ isDeterministic = scalarFunc.isDeterministic)
case _ =>
// TODO: handle functions defined in Scala too - in Scala, even if a
// subclass do not override the default method in parent interface
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala
index b33b9ed..da4000f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala
@@ -31,6 +31,8 @@ case class ApplyFunctionExpression(
override def name: String = function.name()
override def dataType: DataType = function.resultType()
override def inputTypes: Seq[AbstractDataType] = function.inputTypes().toSeq
+ // Deterministic only when the function itself reports being deterministic and every
+ // child expression is deterministic as well.
+ override lazy val deterministic: Boolean = function.isDeterministic &&
+ children.forall(_.deterministic)
private lazy val reusedRow = new SpecificInternalRow(function.inputTypes())
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 50e2140..6d251b6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -240,6 +240,8 @@ object SerializerSupport {
* without invoking the function.
* @param returnNullable When false, indicating the invoked method will always return
* non-null value.
+ * @param isDeterministic Whether the method invocation is deterministic or not. If false, Spark
+ * will not apply certain optimizations such as constant folding.
*/
case class StaticInvoke(
staticObject: Class[_],
@@ -248,7 +250,8 @@ case class StaticInvoke(
arguments: Seq[Expression] = Nil,
inputTypes: Seq[AbstractDataType] = Nil,
propagateNull: Boolean = true,
- returnNullable: Boolean = true) extends InvokeLike {
+ returnNullable: Boolean = true,
+ isDeterministic: Boolean = true) extends InvokeLike {
val objectName = staticObject.getName.stripSuffix("$")
val cls = if (staticObject.getName == objectName) {
@@ -259,6 +262,7 @@ case class StaticInvoke(
override def nullable: Boolean = needNullCheck || returnNullable
override def children: Seq[Expression] = arguments
+ // Deterministic only when the caller flagged the method as deterministic and every
+ // argument expression (these are all of this node's children) is deterministic too.
+ override lazy val deterministic: Boolean = isDeterministic && arguments.forall(_.deterministic)
lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments)
@transient lazy val method = findMethod(cls, functionName, argClasses)
@@ -340,6 +344,8 @@ case class StaticInvoke(
* without invoking the function.
* @param returnNullable When false, indicating the invoked method will always return
* non-null value.
+ * @param isDeterministic Whether the method invocation is deterministic or not. If false, Spark
+ * will not apply certain optimizations such as constant folding.
*/
case class Invoke(
targetObject: Expression,
@@ -348,12 +354,14 @@ case class Invoke(
arguments: Seq[Expression] = Nil,
methodInputTypes: Seq[AbstractDataType] = Nil,
propagateNull: Boolean = true,
- returnNullable : Boolean = true) extends InvokeLike {
+ returnNullable : Boolean = true,
+ isDeterministic: Boolean = true) extends InvokeLike {
lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments)
override def nullable: Boolean = targetObject.nullable || needNullCheck || returnNullable
override def children: Seq[Expression] = targetObject +: arguments
+ // `children` is the target object followed by the arguments. The target object is evaluated
+ // like any other child, so it has to be deterministic too; checking `arguments` alone would
+ // report a call on a non-deterministic target as deterministic.
+ override lazy val deterministic: Boolean = isDeterministic && children.forall(_.deterministic)
override def inputTypes: Seq[AbstractDataType] =
if (methodInputTypes.nonEmpty) {
Seq(targetObject.dataType) ++ methodInputTypes
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaLongAdd.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaLongAdd.java
index e5b9c7f..75ef527 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaLongAdd.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaLongAdd.java
@@ -66,7 +66,7 @@ public class JavaLongAdd implements UnboundFunction {
return "long_add";
}
- private abstract static class JavaLongAddBase implements ScalarFunction<Long> {
+ public abstract static class JavaLongAddBase implements ScalarFunction<Long> {
private final boolean isResultNullable;
JavaLongAddBase(boolean isResultNullable) {
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaRandomAdd.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaRandomAdd.java
new file mode 100644
index 0000000..b315faf
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaRandomAdd.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.connector.catalog.functions;
+
+import java.util.Random;
+
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
+import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
+import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.IntegerType;
+import org.apache.spark.sql.types.StructType;
+
+/**
+ * Test V2 function that adds a random number to the input integer.
+ */
+public class JavaRandomAdd implements UnboundFunction {
+  /** The bound implementation handed back by {@link #bind(StructType)}. */
+  private final BoundFunction fn;
+
+  /**
+   * @param fn the bound function returned whenever binding succeeds
+   */
+  public JavaRandomAdd(BoundFunction fn) {
+    this.fn = fn;
+  }
+
+  /**
+   * Returns the function name. It matches the name used by {@link #description()} and by
+   * {@link JavaRandomAddBase#name()}.
+   */
+  @Override
+  public String name() {
+    return "rand_add";
+  }
+
+  /**
+   * Binds the function to the given input type.
+   *
+   * @param inputType the struct describing the call arguments
+   * @return the wrapped bound function when there is exactly one integer argument
+   * @throws UnsupportedOperationException if the argument count is not one, or if the single
+   *         argument is not an {@link IntegerType}
+   */
+  @Override
+  public BoundFunction bind(StructType inputType) {
+    if (inputType.fields().length != 1) {
+      throw new UnsupportedOperationException("Expect exactly one argument");
+    }
+    if (inputType.fields()[0].dataType() instanceof IntegerType) {
+      return fn;
+    }
+    throw new UnsupportedOperationException("Expect IntegerType");
+  }
+
+  @Override
+  public String description() {
+    return "rand_add: add a random integer to the input\n" +
+      "rand_add(int) -> int";
+  }
+
+  /**
+   * Shared metadata for all `rand_add` implementations: one integer input, an integer result,
+   * and a non-deterministic flag.
+   */
+  public abstract static class JavaRandomAddBase implements ScalarFunction<Integer> {
+    @Override
+    public DataType[] inputTypes() {
+      return new DataType[] { DataTypes.IntegerType };
+    }
+
+    @Override
+    public DataType resultType() {
+      return DataTypes.IntegerType;
+    }
+
+    @Override
+    public String name() {
+      return "rand_add";
+    }
+
+    /** Always {@code false}: every call adds a fresh random number. */
+    @Override
+    public boolean isDeterministic() {
+      return false;
+    }
+  }
+
+  /** Variant that implements {@code produceResult} on an {@link InternalRow}. */
+  public static class JavaRandomAddDefault extends JavaRandomAddBase {
+    private final Random rand = new Random();
+
+    @Override
+    public Integer produceResult(InternalRow input) {
+      // The single argument lives at ordinal 0; int addition wraps around on overflow.
+      return input.getInt(0) + rand.nextInt();
+    }
+  }
+
+  /** Variant that exposes an instance-level {@code invoke(int)} method. */
+  public static class JavaRandomAddMagic extends JavaRandomAddBase {
+    private final Random rand = new Random();
+
+    public int invoke(int input) {
+      return input + rand.nextInt();
+    }
+  }
+
+  /** Variant that exposes a static {@code invoke(int)} method backed by a shared generator. */
+  public static class JavaRandomAddStaticMagic extends JavaRandomAddBase {
+    private static final Random rand = new Random();
+
+    public static int invoke(int input) {
+      return input + rand.nextInt();
+    }
+  }
+}
+
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java
index 1b16896..dade2a1 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java
@@ -49,7 +49,7 @@ public class JavaStrLen implements UnboundFunction {
return fn;
}
- throw new UnsupportedOperationException("Except StringType");
+ throw new UnsupportedOperationException("Expect StringType");
}
@Override
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
index d5417be..ace6619 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
@@ -20,16 +20,18 @@ package org.apache.spark.sql.connector
import java.util
import java.util.Collections
-import test.org.apache.spark.sql.connector.catalog.functions.{JavaAverage, JavaLongAdd, JavaStrLen}
-import test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd.{JavaLongAddDefault, JavaLongAddMagic, JavaLongAddMismatchMagic, JavaLongAddStaticMagic}
+import test.org.apache.spark.sql.connector.catalog.functions._
+import test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd._
+import test.org.apache.spark.sql.connector.catalog.functions.JavaRandomAdd._
import test.org.apache.spark.sql.connector.catalog.functions.JavaStrLen._
import org.apache.spark.SparkException
-import org.apache.spark.sql.{AnalysisException, Row}
+import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode.{FALLBACK, NO_CODEGEN}
import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, InMemoryCatalog, SupportsNamespaces}
import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction, _}
+import org.apache.spark.sql.execution.ProjectExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -365,6 +367,31 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase {
}
}
+  test("SPARK-37957: pass deterministic flag when creating V2 function expression") {
+    // Asserts that the executed plan contains a ProjectExec and that none of the expressions
+    // in its project list is reported as deterministic.
+    def checkDeterministic(df: DataFrame): Unit = {
+      val project = df.queryExecution.executedPlan
+        .collectFirst { case p: ProjectExec => p }
+        .getOrElse(fail("Expect to find ProjectExec"))
+      assert(!project.projectList.exists(_.deterministic),
+        "Expect expressions in projectList to be non-deterministic")
+    }
+
+    catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
+    val randAddIdent = Identifier.of(Array("ns"), "rand_add")
+
+    // The function itself is non-deterministic, whichever way it is implemented.
+    Seq(new JavaRandomAddDefault, new JavaRandomAddMagic,
+        new JavaRandomAddStaticMagic).foreach { fn =>
+      addFunction(randAddIdent, new JavaRandomAdd(fn))
+      checkDeterministic(sql("SELECT testcat.ns.rand_add(42)"))
+    }
+
+    // A function call is non-deterministic if one of its arguments is non-deterministic.
+    // The argument function is the same for every case, so it is registered once up front.
+    addFunction(randAddIdent, new JavaRandomAdd(new JavaRandomAddDefault))
+    Seq(new JavaLongAddDefault(true), new JavaLongAddMagic(true),
+        new JavaLongAddStaticMagic(true)).foreach { fn =>
+      addFunction(Identifier.of(Array("ns"), "add"), new JavaLongAdd(fn))
+      checkDeterministic(sql("SELECT testcat.ns.add(10, testcat.ns.rand_add(42))"))
+    }
+  }
+
private case class StrLen(impl: BoundFunction) extends UnboundFunction {
override def description(): String =
"""strlen: returns the length of the input string
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org