Posted to commits@spark.apache.org by sr...@apache.org on 2022/09/16 13:30:04 UTC

[spark] branch master updated: [SPARK-40398][CORE][SQL] Use Loop instead of Arrays.stream api

This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d71b180295e [SPARK-40398][CORE][SQL] Use Loop instead of Arrays.stream api
d71b180295e is described below

commit d71b180295ea89b39047cff8397c5b3c2fe0bd20
Author: yangjie01 <ya...@baidu.com>
AuthorDate: Fri Sep 16 08:29:31 2022 -0500

    [SPARK-40398][CORE][SQL] Use Loop instead of Arrays.stream api
    
    ### What changes were proposed in this pull request?
    This PR replaces the `Arrays.stream` API with plain loops in places where a performance improvement can be obtained.
    
    ### Why are the changes needed?
    Minor performance improvement.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Pass GitHub Actions
    
    Closes #37843 from LuciferYang/ExpressionArrayToStrings.
    
    Authored-by: yangjie01 <ya...@baidu.com>
    Signed-off-by: Sean Owen <sr...@gmail.com>
---
 .../network/shuffle/OneForOneBlockFetcher.java     | 24 ++++++++-
 .../sql/connector/expressions/Expression.java      | 20 +++++--
 .../sql/connector/metric/CustomAvgMetric.java      |  7 ++-
 .../sql/connector/metric/CustomSumMetric.java      |  8 +--
 .../sql/connector/util/V2ExpressionSQLBuilder.java | 62 +++++++++++++---------
 .../datasources/v2/V2PredicateSuite.scala          |  4 +-
 6 files changed, 87 insertions(+), 38 deletions(-)
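
All of the changes below follow one pattern: a hand-written loop replaces a
`java.util.stream` pipeline on a hot path. As a minimal, self-contained sketch
of that pattern (illustrative only; the real call sites are in the diffs below,
and the class and method names here are made up):

    import java.util.Arrays;

    class StreamVsLoop {
      // Before: each call allocates a Stream pipeline (spliterator plus
      // pipeline objects) before any element is examined.
      static boolean anyNegativeStream(int[] values) {
        return Arrays.stream(values).anyMatch(v -> v < 0);
      }

      // After: identical semantics, including short-circuiting on the first
      // match, with no intermediate allocations.
      static boolean anyNegativeLoop(int[] values) {
        for (int v : values) {
          if (v < 0) {
            return true;
          }
        }
        return false;
      }

      public static void main(String[] args) {
        int[] sample = {1, 2, -3};
        System.out.println(anyNegativeStream(sample)); // true
        System.out.println(anyNegativeLoop(sample));   // true
      }
    }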

diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
index a788b508e7b..b93db3f570b 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
@@ -113,10 +113,30 @@ public class OneForOneBlockFetcher {
    * @return whether the array contains only shuffle block IDs
    */
   private boolean areShuffleBlocksOrChunks(String[] blockIds) {
-    if (Arrays.stream(blockIds).anyMatch(blockId -> !blockId.startsWith(SHUFFLE_BLOCK_PREFIX))) {
+    if (isAnyBlockNotStartWithShuffleBlockPrefix(blockIds)) {
       // It comes here because there is a blockId which doesn't have "shuffle_" prefix so we
       // check if all the block ids are shuffle chunk Ids.
-      return Arrays.stream(blockIds).allMatch(blockId -> blockId.startsWith(SHUFFLE_CHUNK_PREFIX));
+      return isAllBlocksStartWithShuffleChunkPrefix(blockIds);
+    }
+    return true;
+  }
+
+  // SPARK-40398: Replace `Arrays.stream().anyMatch()` with this method for a perf gain.
+  private static boolean isAnyBlockNotStartWithShuffleBlockPrefix(String[] blockIds) {
+    for (String blockId : blockIds) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX)) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  // SPARK-40398: Replace `Arrays.stream().allMatch()` with this method for a perf gain.
+  private static boolean isAllBlocksStartWithShuffleChunkPrefix(String[] blockIds) {
+    for (String blockId : blockIds) {
+      if (!blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
+        return false;
+      }
     }
     return true;
   }
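
The commit message claims only a minor improvement; a micro-benchmark is the
honest way to check that. A sketch of how it could be measured with JMH (a
hypothetical harness, not part of this commit; it assumes the
org.openjdk.jmh dependency, and PrefixMatchBenchmark and its fields are
made-up names):

    import java.util.Arrays;
    import org.openjdk.jmh.annotations.Benchmark;
    import org.openjdk.jmh.annotations.Param;
    import org.openjdk.jmh.annotations.Scope;
    import org.openjdk.jmh.annotations.Setup;
    import org.openjdk.jmh.annotations.State;

    @State(Scope.Benchmark)
    public class PrefixMatchBenchmark {
      @Param({"16", "1024"})
      int size;
      String[] blockIds;

      @Setup
      public void setup() {
        blockIds = new String[size];
        Arrays.fill(blockIds, "shuffle_0_0_0"); // worst case: no early exit
      }

      @Benchmark
      public boolean streamAnyMatch() {
        return Arrays.stream(blockIds).anyMatch(id -> !id.startsWith("shuffle_"));
      }

      @Benchmark
      public boolean loopAnyMatch() {
        for (String id : blockIds) {
          if (!id.startsWith("shuffle_")) {
            return true;
          }
        }
        return false;
      }
    }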
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expression.java
index 76dfe73f666..25953ec32e4 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expression.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expression.java
@@ -17,7 +17,9 @@
 
 package org.apache.spark.sql.connector.expressions;
 
-import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Set;
 
 import org.apache.spark.annotation.Evolving;
 
@@ -30,6 +32,13 @@ import org.apache.spark.annotation.Evolving;
 public interface Expression {
   Expression[] EMPTY_EXPRESSION = new Expression[0];
 
+  /**
+   * `EMPTY_NAMED_REFERENCE` is only used as an input when the
+   * default `references` method builds the result array to avoid
+   * repeatedly allocating an empty array.
+   */
+  NamedReference[] EMPTY_NAMED_REFERENCE = new NamedReference[0];
+
   /**
    * Format the expression as a human readable SQL-like string.
    */
@@ -44,7 +53,12 @@ public interface Expression {
    * List of fields or columns that are referenced by this expression.
    */
   default NamedReference[] references() {
-    return Arrays.stream(children()).map(e -> e.references())
-      .flatMap(Arrays::stream).distinct().toArray(NamedReference[]::new);
+    // SPARK-40398: Replace `Arrays.stream()...distinct()` with this
+    // for a perf gain; the result order is not important.
+    Set<NamedReference> set = new HashSet<>();
+    for (Expression e : children()) {
+      Collections.addAll(set, e.references());
+    }
+    return set.toArray(EMPTY_NAMED_REFERENCE);
   }
 }
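
Two details of this rewrite are worth spelling out. First, `Stream.distinct()`
preserves encounter order while `HashSet` iteration order is unspecified,
which is exactly why the V2PredicateSuite assertions later in this commit add
a `.sorted`. Second, passing a shared zero-length array to `toArray` is safe:
when the argument is too small, `toArray` allocates a new array of the exact
size, so `EMPTY_NAMED_REFERENCE` is never written to. A standalone sketch of
the same idiom (hypothetical names, plain Strings in place of NamedReferences):

    import java.util.Collections;
    import java.util.HashSet;
    import java.util.Set;

    class ToArrayIdiom {
      // Shared sentinel, reused across calls to avoid one allocation each time.
      private static final String[] EMPTY = new String[0];

      static String[] distinctRefs(String[]... groups) {
        Set<String> set = new HashSet<>();
        for (String[] group : groups) {
          Collections.addAll(set, group);
        }
        // EMPTY is too small to hold the result, so toArray allocates an
        // exact-size array and leaves the shared instance untouched.
        return set.toArray(EMPTY);
      }

      public static void main(String[] args) {
        String[] refs = distinctRefs(new String[]{"a", "b"}, new String[]{"b", "c"});
        System.out.println(refs.length); // 3, in unspecified order
      }
    }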
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomAvgMetric.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomAvgMetric.java
index 71e83002dda..99ac3ac8d20 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomAvgMetric.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomAvgMetric.java
@@ -19,7 +19,6 @@ package org.apache.spark.sql.connector.metric;
 
 import org.apache.spark.annotation.Evolving;
 
-import java.util.Arrays;
 import java.text.DecimalFormat;
 
 /**
@@ -33,7 +32,11 @@ public abstract class CustomAvgMetric implements CustomMetric {
   @Override
   public String aggregateTaskMetrics(long[] taskMetrics) {
     if (taskMetrics.length > 0) {
-      double average = ((double)Arrays.stream(taskMetrics).sum()) / taskMetrics.length;
+      long sum = 0L;
+      for (long taskMetric : taskMetrics) {
+        sum += taskMetric;
+      }
+      double average = ((double) sum) / taskMetrics.length;
       return new DecimalFormat("#0.000").format(average);
     } else {
       return "0";
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomSumMetric.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomSumMetric.java
index ba28e9b9187..47d0ae9b805 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomSumMetric.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/metric/CustomSumMetric.java
@@ -19,8 +19,6 @@ package org.apache.spark.sql.connector.metric;
 
 import org.apache.spark.annotation.Evolving;
 
-import java.util.Arrays;
-
 /**
  * Built-in `CustomMetric` that sums up metric values. Note that please extend this class
  * and override `name` and `description` to create your custom metric for real usage.
@@ -31,6 +29,10 @@ import java.util.Arrays;
 public abstract class CustomSumMetric implements CustomMetric {
   @Override
   public String aggregateTaskMetrics(long[] taskMetrics) {
-    return String.valueOf(Arrays.stream(taskMetrics).sum());
+    long sum = 0L;
+    for (long taskMetric : taskMetrics) {
+      sum += taskMetric;
+    }
+    return String.valueOf(sum);
   }
 }
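
Both metric classes replace `Arrays.stream(taskMetrics).sum()` with a manual
accumulation. The two are exactly equivalent, including on overflow, since
both perform wrap-around `long` addition; only the stream setup cost
disappears. A quick standalone check (hypothetical class name):

    import java.util.Arrays;

    class SumEquivalence {
      public static void main(String[] args) {
        long[] metrics = {3L, 5L, Long.MAX_VALUE};
        long sum = 0L;
        for (long m : metrics) {
          sum += m; // wraps on overflow, same as LongStream.sum()
        }
        System.out.println(sum == Arrays.stream(metrics).sum()); // true
        System.out.println(String.valueOf(sum));
      }
    }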
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
index 315b3309054..b32958d13da 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
@@ -17,10 +17,9 @@
 
 package org.apache.spark.sql.connector.util;
 
-import java.util.Arrays;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.StringJoiner;
-import java.util.stream.Collectors;
 
 import org.apache.spark.sql.connector.expressions.Cast;
 import org.apache.spark.sql.connector.expressions.Expression;
@@ -62,9 +61,9 @@ public class V2ExpressionSQLBuilder {
       String name = e.name();
       switch (name) {
         case "IN": {
-          List<String> children =
-            Arrays.stream(e.children()).map(c -> build(c)).collect(Collectors.toList());
-          return visitIn(children.get(0), children.subList(1, children.size()));
+          Expression[] expressions = e.children();
+          List<String> children = expressionsToStringList(expressions, 1, expressions.length - 1);
+          return visitIn(build(expressions[0]), children);
         }
         case "IS_NULL":
           return visitIsNull(build(e.children()[0]));
@@ -159,25 +158,18 @@ public class V2ExpressionSQLBuilder {
         case "BIT_LENGTH":
         case "CHAR_LENGTH":
         case "CONCAT":
-          return visitSQLFunction(name,
-            Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new));
+          return visitSQLFunction(name, expressionsToStringArray(e.children()));
         case "CASE_WHEN": {
-          List<String> children =
-            Arrays.stream(e.children()).map(c -> build(c)).collect(Collectors.toList());
-          return visitCaseWhen(children.toArray(new String[e.children().length]));
+          return visitCaseWhen(expressionsToStringArray(e.children()));
         }
         case "TRIM":
-          return visitTrim("BOTH",
-            Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new));
+          return visitTrim("BOTH", expressionsToStringArray(e.children()));
         case "LTRIM":
-          return visitTrim("LEADING",
-            Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new));
+          return visitTrim("LEADING", expressionsToStringArray(e.children()));
         case "RTRIM":
-          return visitTrim("TRAILING",
-            Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new));
+          return visitTrim("TRAILING", expressionsToStringArray(e.children()));
         case "OVERLAY":
-          return visitOverlay(
-            Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new));
+          return visitOverlay(expressionsToStringArray(e.children()));
         // TODO supports other expressions
         default:
           return visitUnexpectedExpr(expr);
@@ -185,37 +177,37 @@ public class V2ExpressionSQLBuilder {
     } else if (expr instanceof Min) {
       Min min = (Min) expr;
       return visitAggregateFunction("MIN", false,
-        Arrays.stream(min.children()).map(c -> build(c)).toArray(String[]::new));
+        expressionsToStringArray(min.children()));
     } else if (expr instanceof Max) {
       Max max = (Max) expr;
       return visitAggregateFunction("MAX", false,
-        Arrays.stream(max.children()).map(c -> build(c)).toArray(String[]::new));
+        expressionsToStringArray(max.children()));
     } else if (expr instanceof Count) {
       Count count = (Count) expr;
       return visitAggregateFunction("COUNT", count.isDistinct(),
-        Arrays.stream(count.children()).map(c -> build(c)).toArray(String[]::new));
+        expressionsToStringArray(count.children()));
     } else if (expr instanceof Sum) {
       Sum sum = (Sum) expr;
       return visitAggregateFunction("SUM", sum.isDistinct(),
-        Arrays.stream(sum.children()).map(c -> build(c)).toArray(String[]::new));
+        expressionsToStringArray(sum.children()));
     } else if (expr instanceof CountStar) {
       return visitAggregateFunction("COUNT", false, new String[]{"*"});
     } else if (expr instanceof Avg) {
       Avg avg = (Avg) expr;
       return visitAggregateFunction("AVG", avg.isDistinct(),
-        Arrays.stream(avg.children()).map(c -> build(c)).toArray(String[]::new));
+        expressionsToStringArray(avg.children()));
     } else if (expr instanceof GeneralAggregateFunc) {
       GeneralAggregateFunc f = (GeneralAggregateFunc) expr;
       return visitAggregateFunction(f.name(), f.isDistinct(),
-        Arrays.stream(f.children()).map(c -> build(c)).toArray(String[]::new));
+        expressionsToStringArray(f.children()));
     } else if (expr instanceof UserDefinedScalarFunc) {
       UserDefinedScalarFunc f = (UserDefinedScalarFunc) expr;
       return visitUserDefinedScalarFunction(f.name(), f.canonicalName(),
-        Arrays.stream(f.children()).map(c -> build(c)).toArray(String[]::new));
+        expressionsToStringArray(f.children()));
     } else if (expr instanceof UserDefinedAggregateFunc) {
       UserDefinedAggregateFunc f = (UserDefinedAggregateFunc) expr;
       return visitUserDefinedAggregateFunction(f.name(), f.canonicalName(), f.isDistinct(),
-        Arrays.stream(f.children()).map(c -> build(c)).toArray(String[]::new));
+        expressionsToStringArray(f.children()));
     } else {
       return visitUnexpectedExpr(expr);
     }
@@ -393,4 +385,22 @@ public class V2ExpressionSQLBuilder {
     }
     return joiner.toString();
   }
+
+  private String[] expressionsToStringArray(Expression[] expressions) {
+    String[] result = new String[expressions.length];
+    for (int i = 0; i < expressions.length; i++) {
+      result[i] = build(expressions[i]);
+    }
+    return result;
+  }
+
+  private List<String> expressionsToStringList(Expression[] expressions, int offset, int length) {
+    List<String> list = new ArrayList<>(length);
+    final int till = Math.min(offset + length, expressions.length);
+    while (offset < till) {
+      list.add(build(expressions[offset]));
+      offset++;
+    }
+    return list;
+  }
 }
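
The `expressionsToStringList` helper takes an (offset, length) slice so the
`IN` case can split the first child (the tested value) from the rest (the
value list) in a single pass. Its contract, reduced to a standalone sketch
(hypothetical names, plain Strings standing in for built expressions):

    import java.util.ArrayList;
    import java.util.List;

    class SliceDemo {
      // Converts `length` elements starting at `offset`, clamped to the
      // array bounds, mirroring expressionsToStringList.
      static List<String> slice(String[] items, int offset, int length) {
        List<String> list = new ArrayList<>(length);
        final int till = Math.min(offset + length, items.length);
        while (offset < till) {
          list.add(items[offset]);
          offset++;
        }
        return list;
      }

      public static void main(String[] args) {
        String[] children = {"col", "1", "2", "3"};
        // IN: children[0] is the tested expression, the rest form the list.
        System.out.println(slice(children, 1, children.length - 1)); // [1, 2, 3]
      }
    }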
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala
index de556c50f5d..a5fee51dc91 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala
@@ -263,7 +263,7 @@ class V2PredicateSuite extends SparkFunSuite {
       new Predicate("=", Array[Expression](ref("a"), LiteralValue(1, IntegerType))),
       new Predicate("=", Array[Expression](ref("b"), LiteralValue(1, IntegerType))))
     assert(predicate1.equals(predicate2))
-    assert(predicate1.references.map(_.describe()).toSeq == Seq("a", "b"))
+    assert(predicate1.references.map(_.describe()).toSeq.sorted == Seq("a", "b"))
     assert(predicate1.describe.equals("(a = 1) AND (b = 1)"))
 
     val v1Filter = V1And(EqualTo("a", 1), EqualTo("b", 1))
@@ -287,7 +287,7 @@ class V2PredicateSuite extends SparkFunSuite {
       new Predicate("=", Array[Expression](ref("a"), LiteralValue(1, IntegerType))),
       new Predicate("=", Array[Expression](ref("b"), LiteralValue(1, IntegerType))))
     assert(predicate1.equals(predicate2))
-    assert(predicate1.references.map(_.describe()).toSeq == Seq("a", "b"))
+    assert(predicate1.references.map(_.describe()).toSeq.sorted == Seq("a", "b"))
     assert(predicate1.describe.equals("(a = 1) OR (b = 1)"))
 
     val v1Filter = V1Or(EqualTo("a", 1), EqualTo("b", 1))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org