You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by we...@apache.org on 2023/06/19 07:55:40 UTC

[spark] branch branch-3.4 updated: [SPARK-44018][SQL] Improve the hashCode and toString for some DS V2 Expression

This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 864b9869949 [SPARK-44018][SQL] Improve the hashCode and toString for some DS V2 Expression
864b9869949 is described below

commit 864b9869949dbf5dd538a2d5dc59f2894d72af1c
Author: Jiaan Geng <be...@163.com>
AuthorDate: Mon Jun 19 15:55:06 2023 +0800

    [SPARK-44018][SQL] Improve the hashCode and toString for some DS V2 Expression
    
    ### What changes were proposed in this pull request?
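    The `hashCode()` of `UserDefinedScalarFunc` and `GeneralScalarExpression` is not good enough. For example, `GeneralScalarExpression` uses `Objects.hash(name, children)`, which combines the hash code of `name` with the reference-based (identity) hash code of the `children` array to form the `GeneralScalarExpression`'s hash code. As a result, two expressions that are equal according to `equals` (which compares `children` with `Arrays.equals`) can still produce different hash codes.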
    Instead, the hash code should be computed from the hash code of each element in `children`.
    
    Because `UserDefinedAggregateFunc` and `GeneralAggregateFunc` are missing `equals()` and `hashCode()`, this PR also adds them.
    
    This PR also improves the `toString` for `UserDefinedAggregateFunc` and `GeneralAggregateFunc` by using a primitive boolean comparison instead of `Objects.equals`, because a primitive boolean comparison performs better than `Objects.equals`.
    
    ### Why are the changes needed?
    Improve the hash code for some DS V2 Expression.
    
    ### Does this PR introduce _any_ user-facing change?
    'Yes'.
    
    ### How was this patch tested?
    N/A
    
    Closes #41543 from beliefer/SPARK-44018.
    
    Authored-by: Jiaan Geng <be...@163.com>
    Signed-off-by: Wenchen Fan <we...@databricks.com>
    (cherry picked from commit 8c84d2c9349d7b607db949c2e114df781f23e438)
    Signed-off-by: Wenchen Fan <we...@databricks.com>
---
 .../expressions/GeneralScalarExpression.java       | 10 ++++++---
 .../expressions/UserDefinedScalarFunc.java         | 13 ++++++++----
 .../aggregate/GeneralAggregateFunc.java            | 22 ++++++++++++++++++++
 .../aggregate/UserDefinedAggregateFunc.java        | 24 ++++++++++++++++++++++
 4 files changed, 62 insertions(+), 7 deletions(-)

diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
index cb9bf6d69e2..85966060021 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.connector.expressions;
 
 import java.util.Arrays;
-import java.util.Objects;
 
 import org.apache.spark.annotation.Evolving;
 import org.apache.spark.sql.connector.expressions.filter.Predicate;
@@ -441,12 +440,17 @@ public class GeneralScalarExpression extends ExpressionWithToString {
   public boolean equals(Object o) {
     if (this == o) return true;
     if (o == null || getClass() != o.getClass()) return false;
+
     GeneralScalarExpression that = (GeneralScalarExpression) o;
-    return Objects.equals(name, that.name) && Arrays.equals(children, that.children);
+
+    if (!name.equals(that.name)) return false;
+    return Arrays.equals(children, that.children);
   }
 
   @Override
   public int hashCode() {
-    return Objects.hash(name, children);
+    int result = name.hashCode();
+    result = 31 * result + Arrays.hashCode(children);
+    return result;
   }
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/UserDefinedScalarFunc.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/UserDefinedScalarFunc.java
index b7f603cd431..cbf3941d77d 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/UserDefinedScalarFunc.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/UserDefinedScalarFunc.java
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.connector.expressions;
 
 import java.util.Arrays;
-import java.util.Objects;
 
 import org.apache.spark.annotation.Evolving;
 import org.apache.spark.sql.internal.connector.ExpressionWithToString;
@@ -51,13 +50,19 @@ public class UserDefinedScalarFunc extends ExpressionWithToString {
   public boolean equals(Object o) {
     if (this == o) return true;
     if (o == null || getClass() != o.getClass()) return false;
+
     UserDefinedScalarFunc that = (UserDefinedScalarFunc) o;
-    return Objects.equals(name, that.name) && Objects.equals(canonicalName, that.canonicalName) &&
-      Arrays.equals(children, that.children);
+
+    if (!name.equals(that.name)) return false;
+    if (!canonicalName.equals(that.canonicalName)) return false;
+    return Arrays.equals(children, that.children);
   }
 
   @Override
   public int hashCode() {
-    return Objects.hash(name, canonicalName, children);
+    int result = name.hashCode();
+    result = 31 * result + canonicalName.hashCode();
+    result = 31 * result + Arrays.hashCode(children);
+    return result;
   }
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
index 1abf3865659..4ef5b7f97e9 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.connector.expressions.aggregate;
 
+import java.util.Arrays;
+
 import org.apache.spark.annotation.Evolving;
 import org.apache.spark.sql.connector.expressions.Expression;
 import org.apache.spark.sql.internal.connector.ExpressionWithToString;
@@ -60,4 +62,24 @@ public final class GeneralAggregateFunc extends ExpressionWithToString implement
 
   @Override
   public Expression[] children() { return children; }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) return true;
+    if (o == null || getClass() != o.getClass()) return false;
+
+    GeneralAggregateFunc that = (GeneralAggregateFunc) o;
+
+    if (isDistinct != that.isDistinct) return false;
+    if (!name.equals(that.name)) return false;
+    return Arrays.equals(children, that.children);
+  }
+
+  @Override
+  public int hashCode() {
+    int result = name.hashCode();
+    result = 31 * result + (isDistinct ? 1 : 0);
+    result = 31 * result + Arrays.hashCode(children);
+    return result;
+  }
 }
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/UserDefinedAggregateFunc.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/UserDefinedAggregateFunc.java
index d166ba16ba5..10a62d0478b 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/UserDefinedAggregateFunc.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/UserDefinedAggregateFunc.java
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.connector.expressions.aggregate;
 
+import java.util.Arrays;
+
 import org.apache.spark.annotation.Evolving;
 import org.apache.spark.sql.connector.expressions.Expression;
 import org.apache.spark.sql.internal.connector.ExpressionWithToString;
@@ -50,4 +52,26 @@ public class UserDefinedAggregateFunc extends ExpressionWithToString implements
 
   @Override
   public Expression[] children() { return children; }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) return true;
+    if (o == null || getClass() != o.getClass()) return false;
+
+    UserDefinedAggregateFunc that = (UserDefinedAggregateFunc) o;
+
+    if (isDistinct != that.isDistinct) return false;
+    if (!name.equals(that.name)) return false;
+    if (!canonicalName.equals(that.canonicalName)) return false;
+    return Arrays.equals(children, that.children);
+  }
+
+  @Override
+  public int hashCode() {
+    int result = name.hashCode();
+    result = 31 * result + canonicalName.hashCode();
+    result = 31 * result + (isDistinct ? 1 : 0);
+    result = 31 * result + Arrays.hashCode(children);
+    return result;
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org