Posted to commits@spark.apache.org by su...@apache.org on 2022/01/19 22:10:55 UTC

[spark] branch branch-3.2 updated: [SPARK-37957][SQL] Correctly pass deterministic flag for V2 scalar functions

This is an automated email from the ASF dual-hosted git repository.

sunchao pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new 3860ac5  [SPARK-37957][SQL] Correctly pass deterministic flag for V2 scalar functions
3860ac5 is described below

commit 3860ac5d2313d438c25e5c46ed4e2b8b2b5227e3
Author: Chao Sun <su...@apple.com>
AuthorDate: Wed Jan 19 14:07:30 2022 -0800

    [SPARK-37957][SQL] Correctly pass deterministic flag for V2 scalar functions
    
    ### What changes were proposed in this pull request?
    
    Pass the `isDeterministic` flag to `ApplyFunctionExpression`, `Invoke` and `StaticInvoke` when processing V2 scalar functions.
    
    ### Why are the changes needed?
    
    A V2 scalar function can be declared as non-deterministic. However, Spark currently does not pass this flag along when converting the V2 function to a catalyst expression, which can lead to incorrect results when certain optimizations, such as constant folding, are applied.
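    
    For illustration, a connector marks a scalar function as non-deterministic through `BoundFunction.isDeterministic`. A minimal sketch in Scala (the class name and logic are illustrative, not part of this patch):
    
    ```scala
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.connector.catalog.functions.ScalarFunction
    import org.apache.spark.sql.types.{DataType, DataTypes}
    
    // Adds random noise to its input, so the result changes between
    // invocations and must never be constant-folded.
    class RandomShift extends ScalarFunction[java.lang.Integer] {
      private val rand = new java.util.Random()
    
      override def inputTypes(): Array[DataType] = Array(DataTypes.IntegerType)
      override def resultType(): DataType = DataTypes.IntegerType
      override def name(): String = "random_shift"
    
      // This is the flag that was previously dropped during analysis.
      override def isDeterministic(): Boolean = false
    
      override def produceResult(input: InternalRow): java.lang.Integer =
        input.getInt(0) + rand.nextInt()
    }
    ```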
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Added a unit test.
    
    Closes #35243 from sunchao/SPARK-37957.
    
    Authored-by: Chao Sun <su...@apple.com>
    Signed-off-by: Chao Sun <su...@apple.com>
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     |   6 +-
 .../expressions/ApplyFunctionExpression.scala      |   2 +
 .../sql/catalyst/expressions/objects/objects.scala |  12 ++-
 .../connector/catalog/functions/JavaLongAdd.java   |   2 +-
 .../connector/catalog/functions/JavaRandomAdd.java | 110 +++++++++++++++++++++
 .../connector/catalog/functions/JavaStrLen.java    |   2 +-
 .../sql/connector/DataSourceV2FunctionSuite.scala  |  33 ++++++-
 7 files changed, 158 insertions(+), 9 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 89c7b5f..42bfa24 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2299,12 +2299,14 @@ class Analyzer(override val catalogManager: CatalogManager)
           case Some(m) if Modifier.isStatic(m.getModifiers) =>
             StaticInvoke(scalarFunc.getClass, scalarFunc.resultType(),
               MAGIC_METHOD_NAME, arguments, inputTypes = declaredInputTypes,
-                propagateNull = false, returnNullable = scalarFunc.isResultNullable)
+                propagateNull = false, returnNullable = scalarFunc.isResultNullable,
+                isDeterministic = scalarFunc.isDeterministic)
           case Some(_) =>
             val caller = Literal.create(scalarFunc, ObjectType(scalarFunc.getClass))
             Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(),
               arguments, methodInputTypes = declaredInputTypes, propagateNull = false,
-              returnNullable = scalarFunc.isResultNullable)
+              returnNullable = scalarFunc.isResultNullable,
+              isDeterministic = scalarFunc.isDeterministic)
           case _ =>
             // TODO: handle functions defined in Scala too - in Scala, even if a
             //  subclass do not override the default method in parent interface
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala
index b33b9ed..da4000f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ApplyFunctionExpression.scala
@@ -31,6 +31,8 @@ case class ApplyFunctionExpression(
   override def name: String = function.name()
   override def dataType: DataType = function.resultType()
   override def inputTypes: Seq[AbstractDataType] = function.inputTypes().toSeq
+  override lazy val deterministic: Boolean = function.isDeterministic &&
+      children.forall(_.deterministic)
 
   private lazy val reusedRow = new SpecificInternalRow(function.inputTypes())
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 50e2140..6d251b6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -240,6 +240,8 @@ object SerializerSupport {
  *                      without invoking the function.
  * @param returnNullable When false, indicating the invoked method will always return
  *                       non-null value.
+ * @param isDeterministic Whether the method invocation is deterministic or not. If false, Spark
+ *                        will not apply certain optimizations such as constant folding.
  */
 case class StaticInvoke(
     staticObject: Class[_],
@@ -248,7 +250,8 @@ case class StaticInvoke(
     arguments: Seq[Expression] = Nil,
     inputTypes: Seq[AbstractDataType] = Nil,
     propagateNull: Boolean = true,
-    returnNullable: Boolean = true) extends InvokeLike {
+    returnNullable: Boolean = true,
+    isDeterministic: Boolean = true) extends InvokeLike {
 
   val objectName = staticObject.getName.stripSuffix("$")
   val cls = if (staticObject.getName == objectName) {
@@ -259,6 +262,7 @@ case class StaticInvoke(
 
   override def nullable: Boolean = needNullCheck || returnNullable
   override def children: Seq[Expression] = arguments
+  override lazy val deterministic: Boolean = isDeterministic && arguments.forall(_.deterministic)
 
   lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments)
   @transient lazy val method = findMethod(cls, functionName, argClasses)
@@ -340,6 +344,8 @@ case class StaticInvoke(
  *                      without invoking the function.
  * @param returnNullable When false, indicating the invoked method will always return
  *                       non-null value.
+ * @param isDeterministic Whether the method invocation is deterministic or not. If false, Spark
+ *                        will not apply certain optimizations such as constant folding.
  */
 case class Invoke(
     targetObject: Expression,
@@ -348,12 +354,14 @@ case class Invoke(
     arguments: Seq[Expression] = Nil,
     methodInputTypes: Seq[AbstractDataType] = Nil,
     propagateNull: Boolean = true,
-    returnNullable : Boolean = true) extends InvokeLike {
+    returnNullable : Boolean = true,
+    isDeterministic: Boolean = true) extends InvokeLike {
 
   lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments)
 
   override def nullable: Boolean = targetObject.nullable || needNullCheck || returnNullable
   override def children: Seq[Expression] = targetObject +: arguments
+  override lazy val deterministic: Boolean = isDeterministic && arguments.forall(_.deterministic)
   override def inputTypes: Seq[AbstractDataType] =
     if (methodInputTypes.nonEmpty) {
       Seq(targetObject.dataType) ++ methodInputTypes
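
Taken together, the two new parameters make determinism explicit on the invocation expressions: an invocation is deterministic only when the function itself is declared deterministic and every argument is deterministic. A quick illustration against the catalyst API (a hypothetical snippet, not part of the patch):

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.types.IntegerType

// Even though every argument is a literal, isDeterministic = false keeps
// the whole invocation non-deterministic, so the optimizer will not treat
// it as safe for rewrites such as constant folding.
val invoke = StaticInvoke(
  staticObject = classOf[Math],
  dataType = IntegerType,
  functionName = "abs",
  arguments = Seq(Literal(42)),
  isDeterministic = false)

assert(!invoke.deterministic)
```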
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaLongAdd.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaLongAdd.java
index e5b9c7f..75ef527 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaLongAdd.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaLongAdd.java
@@ -66,7 +66,7 @@ public class JavaLongAdd implements UnboundFunction {
     return "long_add";
   }
 
-  private abstract static class JavaLongAddBase implements ScalarFunction<Long> {
+  public abstract static class JavaLongAddBase implements ScalarFunction<Long> {
     private final boolean isResultNullable;
 
     JavaLongAddBase(boolean isResultNullable) {
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaRandomAdd.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaRandomAdd.java
new file mode 100644
index 0000000..b315faf
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaRandomAdd.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.connector.catalog.functions;
+
+import java.util.Random;
+
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.connector.catalog.functions.BoundFunction;
+import org.apache.spark.sql.connector.catalog.functions.ScalarFunction;
+import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.IntegerType;
+import org.apache.spark.sql.types.StructType;
+
+/**
+ * Test V2 function which adds a random number to the input integer.
+ */
+public class JavaRandomAdd implements UnboundFunction {
+  private final BoundFunction fn;
+
+  public JavaRandomAdd(BoundFunction fn) {
+    this.fn = fn;
+  }
+
+  @Override
+  public String name() {
+    return "rand";
+  }
+
+  @Override
+  public BoundFunction bind(StructType inputType) {
+    if (inputType.fields().length != 1) {
+      throw new UnsupportedOperationException("Expect exactly one argument");
+    }
+    if (inputType.fields()[0].dataType() instanceof IntegerType) {
+      return fn;
+    }
+    throw new UnsupportedOperationException("Expect IntegerType");
+  }
+
+  @Override
+  public String description() {
+    return "rand_add: add a random integer to the input\n" +
+      "rand_add(int) -> int";
+  }
+
+  public abstract static class JavaRandomAddBase implements ScalarFunction<Integer> {
+    @Override
+    public DataType[] inputTypes() {
+      return new DataType[] { DataTypes.IntegerType };
+    }
+
+    @Override
+    public DataType resultType() {
+      return DataTypes.IntegerType;
+    }
+
+    @Override
+    public String name() {
+      return "rand_add";
+    }
+
+    @Override
+    public boolean isDeterministic() {
+      return false;
+    }
+  }
+
+  public static class JavaRandomAddDefault extends JavaRandomAddBase {
+    private final Random rand = new Random();
+
+    @Override
+    public Integer produceResult(InternalRow input) {
+      return input.getInt(0) + rand.nextInt();
+    }
+  }
+
+  public static class JavaRandomAddMagic extends JavaRandomAddBase {
+    private final Random rand = new Random();
+
+    public int invoke(int input) {
+      return input + rand.nextInt();
+    }
+  }
+
+  public static class JavaRandomAddStaticMagic extends JavaRandomAddBase {
+    private static final Random rand = new Random();
+
+    public static int invoke(int input) {
+      return input + rand.nextInt();
+    }
+  }
+}
+
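
The three concrete subclasses above deliberately cover the three resolution paths in the `Analyzer` change: `JavaRandomAddDefault` has no magic method and falls back to `produceResult` (`ApplyFunctionExpression`), `JavaRandomAddMagic` exposes an instance `invoke` (`Invoke`), and `JavaRandomAddStaticMagic` a static `invoke` (`StaticInvoke`). A condensed sketch of how the suite below drives them (assuming the suite's `addFunction` helper and the `testcat` catalog):

```scala
Seq(new JavaRandomAddDefault,     // produceResult  -> ApplyFunctionExpression
    new JavaRandomAddMagic,       // invoke(int)    -> Invoke
    new JavaRandomAddStaticMagic  // static invoke  -> StaticInvoke
).foreach { fn =>
  addFunction(Identifier.of(Array("ns"), "rand_add"), new JavaRandomAdd(fn))
  // On every path the projected expression must report deterministic = false;
  // the full assertion lives in DataSourceV2FunctionSuite below.
  sql("SELECT testcat.ns.rand_add(42)")
}
```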
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java
index 1b16896..dade2a1 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java
@@ -49,7 +49,7 @@ public class JavaStrLen implements UnboundFunction {
       return fn;
     }
 
-    throw new UnsupportedOperationException("Except StringType");
+    throw new UnsupportedOperationException("Expect StringType");
   }
 
   @Override
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
index d5417be..ace6619 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
@@ -20,16 +20,18 @@ package org.apache.spark.sql.connector
 import java.util
 import java.util.Collections
 
-import test.org.apache.spark.sql.connector.catalog.functions.{JavaAverage, JavaLongAdd, JavaStrLen}
-import test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd.{JavaLongAddDefault, JavaLongAddMagic, JavaLongAddMismatchMagic, JavaLongAddStaticMagic}
+import test.org.apache.spark.sql.connector.catalog.functions._
+import test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd._
+import test.org.apache.spark.sql.connector.catalog.functions.JavaRandomAdd._
 import test.org.apache.spark.sql.connector.catalog.functions.JavaStrLen._
 
 import org.apache.spark.SparkException
-import org.apache.spark.sql.{AnalysisException, Row}
+import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode.{FALLBACK, NO_CODEGEN}
 import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, InMemoryCatalog, SupportsNamespaces}
 import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction, _}
+import org.apache.spark.sql.execution.ProjectExec
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -365,6 +367,31 @@ class DataSourceV2FunctionSuite extends DatasourceV2SQLBase {
     }
   }
 
+  test("SPARK-37957: pass deterministic flag when creating V2 function expression") {
+    def checkDeterministic(df: DataFrame): Unit = {
+      val result = df.queryExecution.executedPlan.find(_.isInstanceOf[ProjectExec])
+      assert(result.isDefined, s"Expect to find ProjectExec")
+      assert(!result.get.asInstanceOf[ProjectExec].projectList.exists(_.deterministic),
+        "Expect expressions in projectList to be non-deterministic")
+    }
+
+    catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"), emptyProps)
+    Seq(new JavaRandomAddDefault, new JavaRandomAddMagic,
+        new JavaRandomAddStaticMagic).foreach { fn =>
+      addFunction(Identifier.of(Array("ns"), "rand_add"), new JavaRandomAdd(fn))
+      checkDeterministic(sql("SELECT testcat.ns.rand_add(42)"))
+    }
+
+    // A function call is non-deterministic if one of its arguments is non-deterministic
+    Seq(new JavaLongAddDefault(true), new JavaLongAddMagic(true),
+        new JavaLongAddStaticMagic(true)).foreach { fn =>
+      addFunction(Identifier.of(Array("ns"), "add"), new JavaLongAdd(fn))
+      addFunction(Identifier.of(Array("ns"), "rand_add"),
+        new JavaRandomAdd(new JavaRandomAddDefault))
+      checkDeterministic(sql("SELECT testcat.ns.add(10, testcat.ns.rand_add(42))"))
+    }
+  }
+
   private case class StrLen(impl: BoundFunction) extends UnboundFunction {
     override def description(): String =
       """strlen: returns the length of the input string
