Posted to commits@uniffle.apache.org by ck...@apache.org on 2022/11/18 08:52:19 UTC

[incubator-uniffle] branch branch-0.6 updated: [BUG] Fix incorrect spark metrics (#324)

This is an automated email from the ASF dual-hosted git repository.

ckj pushed a commit to branch branch-0.6
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/branch-0.6 by this push:
     new 1a7a3201 [BUG] Fix incorrect spark metrics (#324)
1a7a3201 is described below

commit 1a7a3201da5d477eba335744d6d23cfbed98ba62
Author: Junfan Zhang <zu...@apache.org>
AuthorDate: Fri Nov 18 11:02:27 2022 +0800

    [BUG] Fix incorrect spark metrics (#324)
    
    Fix incorrect Spark metrics. Observed symptoms:
    
    1. The shuffle-read record count and the shuffle-write record count are inconsistent in our internal cluster.
    2. The log does not report the correct fetch bytes; it always shows 0, for example:
    
    `22/11/15 13:54:53 INFO RssShuffleDataIterator: Fetch 0 bytes cost 30791 ms and 53 ms to serialize, 347 ms to decompress with unCompressionLength[274815736]`
    
    No user-facing changes.
    
    Tested with:
    1. UTs
    2. Online Spark 3 jobs
---
 .../shuffle/reader/RssShuffleDataIterator.java     | 15 ++--
 .../spark/shuffle/reader/RssShuffleReader.java     | 28 ++++++-
 .../spark/shuffle/reader/RssShuffleReader.java     |  5 +-
 .../uniffle/test/WriteAndReadMetricsTest.java      | 91 ++++++++++++++++++++++
 4 files changed, 131 insertions(+), 8 deletions(-)
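
For context, both readers now follow the same reporting pattern: shuffle read metrics are
accumulated in a per-task TempShuffleReadMetrics and folded into the task's metrics via
mergeShuffleReadMetrics() once the iterator completes, and the fetch-summary log line reports
locally tracked compressed/uncompressed byte counts instead of reading them back from the
metrics object. A minimal sketch of that pattern follows; it is illustrative only (the helper
class and values are made up), and it sits in an org.apache.spark.* package because, like the
patched readers, it relies on TaskMetrics methods that are package-private to Spark:

    package org.apache.spark.shuffle.reader;

    import org.apache.spark.TaskContext;
    import org.apache.spark.executor.TempShuffleReadMetrics;

    // Illustrative helper, not part of the patch.
    public final class ShuffleReadMetricsSketch {

      public static void example(TaskContext context) {
        // One temp metrics object per reader/iterator.
        TempShuffleReadMetrics temp = context.taskMetrics().createTempShuffleReadMetrics();

        // Called as blocks are fetched and records are deserialized
        // (the values here are placeholders).
        temp.incRemoteBytesRead(1024L); // compressed bytes fetched from shuffle servers
        temp.incFetchWaitTime(5L);      // ms spent waiting for fetch results
        temp.incRecordsRead(100L);      // records handed to the downstream iterator

        // When the iterator is exhausted (the patch does this in the
        // CompletionIterator callback), fold the temp values into TaskMetrics
        // so the stage-level counters and the Spark UI pick them up.
        context.taskMetrics().mergeShuffleReadMetrics();
      }
    }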

diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java
index 23e03641..6706efce 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java
@@ -55,7 +55,8 @@ public class RssShuffleDataIterator<K, C> extends AbstractIterator<Product2<K, C
   private Input deserializationInput = null;
   private DeserializationStream deserializationStream = null;
   private ByteBufInputStream byteBufInputStream = null;
-  private long unCompressionLength = 0;
+  private long compressedBytesLength = 0;
+  private long unCompressedBytesLength = 0;
   private ByteBuffer uncompressedData;
 
   public RssShuffleDataIterator(
@@ -108,8 +109,10 @@ public class RssShuffleDataIterator<K, C> extends AbstractIterator<Product2<K, C
       long fetchDuration = System.currentTimeMillis() - startFetch;
       shuffleReadMetrics.incFetchWaitTime(fetchDuration);
       if (compressedData != null) {
-        shuffleReadMetrics.incRemoteBytesRead(compressedData.limit() - compressedData.position());
-        // Directbytebuffers are not collected in time will cause executor easy 
+        long compressedDataLength = compressedData.limit() - compressedData.position();
+        compressedBytesLength += compressedDataLength;
+        shuffleReadMetrics.incRemoteBytesRead(compressedDataLength);
+        // Directbytebuffers are not collected in time will cause executor easy
         // be killed by cluster managers(such as YARN) for using too much offheap memory
         if (uncompressedData != null && uncompressedData.isDirect()) {
           try {
@@ -121,7 +124,7 @@ public class RssShuffleDataIterator<K, C> extends AbstractIterator<Product2<K, C
         long startDecompress = System.currentTimeMillis();
         uncompressedData = RssShuffleUtils.decompressData(
             compressedData, compressedBlock.getUncompressLength());
-        unCompressionLength += compressedBlock.getUncompressLength();
+        unCompressedBytesLength += compressedBlock.getUncompressLength();
         long decompressDuration = System.currentTimeMillis() - startDecompress;
         decompressTime += decompressDuration;
         // create new iterator for shuffle data
@@ -134,9 +137,9 @@ public class RssShuffleDataIterator<K, C> extends AbstractIterator<Product2<K, C
         // finish reading records, check data consistent
         shuffleReadClient.checkProcessedBlockIds();
         shuffleReadClient.logStatics();
-        LOG.info("Fetch " + shuffleReadMetrics.remoteBytesRead() + " bytes cost " + readTime + " ms and "
+        LOG.info("Fetch " + compressedBytesLength + " bytes cost " + readTime + " ms and "
             + serializeTime + " ms to serialize, " + decompressTime + " ms to decompress with unCompressionLength["
-            + unCompressionLength + "]");
+            + unCompressedBytesLength + "]");
         return false;
       }
     }
diff --git a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index ef97bea3..41d1f6c1 100644
--- a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++ b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -23,6 +23,8 @@ import org.apache.hadoop.conf.Configuration;
 import org.apache.spark.InterruptibleIterator;
 import org.apache.spark.ShuffleDependency;
 import org.apache.spark.TaskContext;
+import org.apache.spark.executor.ShuffleReadMetrics;
+import org.apache.spark.executor.TempShuffleReadMetrics;
 import org.apache.spark.serializer.Serializer;
 import org.apache.spark.shuffle.RssShuffleHandle;
 import org.apache.spark.shuffle.ShuffleReader;
@@ -113,11 +115,12 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
     ShuffleReadClient shuffleReadClient = ShuffleClientFactory.getInstance().createShuffleReadClient(request);
     RssShuffleDataIterator rssShuffleDataIterator = new RssShuffleDataIterator<K, C>(
         shuffleDependency.serializer(), shuffleReadClient,
-        context.taskMetrics().shuffleReadMetrics());
+        new ReadMetrics(context.taskMetrics().createTempShuffleReadMetrics()));
     CompletionIterator completionIterator =
         CompletionIterator$.MODULE$.apply(rssShuffleDataIterator, new AbstractFunction0<BoxedUnit>() {
           @Override
           public BoxedUnit apply() {
+            context.taskMetrics().mergeShuffleReadMetrics();
             return rssShuffleDataIterator.cleanup();
           }
         });
@@ -190,4 +193,27 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
   public Configuration getHadoopConf() {
     return hadoopConf;
   }
+
+  static class ReadMetrics extends ShuffleReadMetrics {
+    private TempShuffleReadMetrics tempShuffleReadMetrics;
+
+    ReadMetrics(TempShuffleReadMetrics tempShuffleReadMetric) {
+      this.tempShuffleReadMetrics = tempShuffleReadMetric;
+    }
+
+    @Override
+    public void incRemoteBytesRead(long v) {
+      tempShuffleReadMetrics.incRemoteBytesRead(v);
+    }
+
+    @Override
+    public void incFetchWaitTime(long v) {
+      tempShuffleReadMetrics.incFetchWaitTime(v);
+    }
+
+    @Override
+    public void incRecordsRead(long v) {
+      tempShuffleReadMetrics.incRecordsRead(v);
+    }
+  }
 }
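
The ReadMetrics bridge above exists because the shared RssShuffleDataIterator (first hunk)
takes a ShuffleReadMetrics, so the Spark 2 reader wraps the task's TempShuffleReadMetrics in a
ShuffleReadMetrics subclass that forwards the three increment calls the iterator makes. The
Spark 3 reader below already hands the iterator its own metrics reporter and only needed the
mergeShuffleReadMetrics() call added to its completion callback.
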
diff --git a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
index a565cfe4..ce5a4590 100644
--- a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
+++ b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/reader/RssShuffleReader.java
@@ -203,7 +203,10 @@ public class RssShuffleReader<K, C> implements ShuffleReader<K, C> {
             shuffleDependency.serializer(), shuffleReadClient,
             readMetrics);
         CompletionIterator<Product2<K, C>, RssShuffleDataIterator<K, C>> completionIterator =
-            CompletionIterator$.MODULE$.apply(iterator, () -> iterator.cleanup());
+            CompletionIterator$.MODULE$.apply(iterator, () -> {
+              context.taskMetrics().mergeShuffleReadMetrics();
+              return iterator.cleanup();
+            });
         iterators.add(completionIterator);
       }
       iterator = iterators.iterator();
diff --git a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/WriteAndReadMetricsTest.java b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/WriteAndReadMetricsTest.java
new file mode 100644
index 00000000..b38b9fb3
--- /dev/null
+++ b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/WriteAndReadMetricsTest.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.test;
+
+import java.lang.reflect.InvocationTargetException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.functions;
+import org.apache.spark.status.AppStatusStore;
+import org.apache.spark.status.api.v1.StageData;
+import org.junit.jupiter.api.Test;
+import scala.collection.Seq;
+
+public class WriteAndReadMetricsTest extends SimpleTestBase {
+
+  @Test
+  public void test() throws Exception {
+    run();
+  }
+
+  @Override
+  public Map runTest(SparkSession spark, String fileName) throws Exception {
+    // wait briefly to make sure the shuffle server is registered
+    Thread.sleep(3000);
+
+    Dataset<Row> df1 = spark.range(0, 100, 1, 10)
+        .select(functions.when(functions.col("id").$less$eq(50), 1)
+            .otherwise(functions.col("id")).as("key1"), functions.col("id").as("value1"));
+    df1.createOrReplaceTempView("table1");
+
+    List list = spark.sql("select count(value1) from table1 group by key1").collectAsList();
+    Map<String, Long> result = new HashMap<>();
+    result.put("size", Long.valueOf(list.size()));
+
+    for (int stageId : spark.sparkContext().statusTracker().getJobInfo(0).get().stageIds()) {
+      long writeRecords = getFirstStageData(spark, stageId).shuffleWriteRecords();
+      long readRecords = getFirstStageData(spark, stageId).shuffleReadRecords();
+      result.put(stageId + "-write-records", writeRecords);
+      result.put(stageId + "-read-records", readRecords);
+    }
+
+    return result;
+  }
+
+  private StageData getFirstStageData(SparkSession spark, int stageId)
+      throws NoSuchMethodException, InvocationTargetException, IllegalAccessException {
+    AppStatusStore statestore = spark.sparkContext().statusStore();
+    try {
+      return ((Seq<StageData>)statestore
+          .getClass()
+          .getDeclaredMethod(
+              "stageData",
+              int.class,
+              boolean.class
+          ).invoke(statestore, stageId, false)).toList().head();
+    } catch (Exception e) {
+      return ((Seq<StageData>)statestore
+          .getClass()
+          .getDeclaredMethod(
+              "stageData",
+              int.class,
+              boolean.class,
+              List.class,
+              boolean.class,
+              double[].class
+          ).invoke(
+              statestore, stageId, false, new ArrayList<>(), true, new double[]{})).toList().head();
+    }
+  }
+}
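
The new test records, for every stage of the group-by job, the shuffle write and read record
counts taken from Spark's status store (using reflection because the AppStatusStore#stageData
signature differs across Spark releases). Assuming SimpleTestBase follows the usual Uniffle
integration-test pattern of running the same workload once with vanilla Spark shuffle and once
with the remote shuffle service and then comparing the returned maps, the check it enables
looks roughly like the sketch below (the session and variable names are hypothetical, not taken
from the test base):

    // Hypothetical comparison performed by the shared test base, not by this file:
    // the per-stage metric maps must match between the two runs, which would catch
    // the inconsistent read-record counts described in the commit message.
    Map<String, Long> withVanillaShuffle = runTest(sparkWithoutRss, fileName);
    Map<String, Long> withUniffle = runTest(sparkWithRss, fileName);
    org.junit.jupiter.api.Assertions.assertEquals(withVanillaShuffle, withUniffle);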