You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@uniffle.apache.org by xi...@apache.org on 2023/02/10 03:47:37 UTC

[incubator-uniffle] branch master updated: [#575] refactor: replace switch-case with EnumMap in ComposedClientReadHandler (#570)

This is an automated email from the ASF dual-hosted git repository.

xianjingfeng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 9da9bb66 [#575] refactor: replace switch-case with EnumMap in ComposedClientReadHandler (#570)
9da9bb66 is described below

commit 9da9bb66f6ece6cf0b89dc931129edda63e149c6
Author: Kaijie Chen <ck...@apache.org>
AuthorDate: Fri Feb 10 11:47:32 2023 +0800

    [#575] refactor: replace switch-case with EnumMap in ComposedClientReadHandler (#570)
    
    ### What changes were proposed in this pull request?
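    Add an enum Tier and use EnumMaps to replace the switch-case statements in
    ComposedClientReadHandler. Also rename getReadBlokNumInfo to getReadBlockNumInfo,
    use Supplier instead of Callable for lazily created read handlers, and replace bare
    assert statements with assertTrue in ShuffleServerWithMemLocalHdfsTest.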
    
    ### Why are the changes needed?
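    Make the code more concise: the per-tier handlers, handler suppliers and metrics are
    kept in EnumMaps keyed by Tier instead of four separate sets of fields selected by
    switch-case.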
    
    ### Does this PR introduce any user-facing change?
    No.
    
    ### How was this patch tested?
    Existing CI.
---
 .../test/ShuffleServerWithMemLocalHdfsTest.java    |  41 ++--
 .../storage/factory/ShuffleHandlerFactory.java     |   4 +-
 .../handler/impl/ComposedClientReadHandler.java    | 234 +++++++--------------
 3 files changed, 97 insertions(+), 182 deletions(-)

diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithMemLocalHdfsTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithMemLocalHdfsTest.java
index c27fe6f9..5fc74979 100644
--- a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithMemLocalHdfsTest.java
+++ b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithMemLocalHdfsTest.java
@@ -52,6 +52,7 @@ import org.apache.uniffle.storage.util.StorageType;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertNotNull;
 import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
 import static org.junit.jupiter.api.Assertions.fail;
 
 public class ShuffleServerWithMemLocalHdfsTest extends ShuffleReadWriteBase {
@@ -211,31 +212,31 @@ public class ShuffleServerWithMemLocalHdfsTest extends ShuffleReadWriteBase {
     assertNull(sdr);
 
     if (checkSkippedMetrics) {
-      String readBlokNumInfo = composedClientReadHandler.getReadBlokNumInfo();
-      assert (readBlokNumInfo.contains("Client read 0 blocks from [" + ssi + "]")
-          && readBlokNumInfo.contains("Skipped[ hot:3 warm:3 cold:2 frozen:0 ]")
-          && readBlokNumInfo.contains("Consumed[ hot:0 warm:0 cold:0 frozen:0 ]"));
+      String readBlockNumInfo = composedClientReadHandler.getReadBlockNumInfo();
+      assertTrue(readBlockNumInfo.contains("Client read 0 blocks from [" + ssi + "]"));
+      assertTrue(readBlockNumInfo.contains("Skipped[ hot:3 warm:3 cold:2 frozen:0 ]"));
+      assertTrue(readBlockNumInfo.contains("Consumed[ hot:0 warm:0 cold:0 frozen:0 ]"));
       String readLengthInfo = composedClientReadHandler.getReadLengthInfo();
-      assert (readLengthInfo.contains("Client read 0 bytes from [" + ssi + "]")
-          && readLengthInfo.contains("Skipped[ hot:75 warm:150 cold:400 frozen:0 ]")
-          && readBlokNumInfo.contains("Consumed[ hot:0 warm:0 cold:0 frozen:0 ]"));
+      assertTrue(readLengthInfo.contains("Client read 0 bytes from [" + ssi + "]"));
+      assertTrue(readLengthInfo.contains("Skipped[ hot:75 warm:150 cold:400 frozen:0 ]"));
+      assertTrue(readBlockNumInfo.contains("Consumed[ hot:0 warm:0 cold:0 frozen:0 ]"));
       String readUncompressLengthInfo = composedClientReadHandler.getReadUncompressLengthInfo();
-      assert (readUncompressLengthInfo.contains("Client read 0 uncompressed bytes from [" + ssi + "]")
-          && readUncompressLengthInfo.contains("Skipped[ hot:75 warm:150 cold:400 frozen:0 ]")
-          && readBlokNumInfo.contains("Consumed[ hot:0 warm:0 cold:0 frozen:0 ]"));
+      assertTrue(readUncompressLengthInfo.contains("Client read 0 uncompressed bytes from [" + ssi + "]"));
+      assertTrue(readUncompressLengthInfo.contains("Skipped[ hot:75 warm:150 cold:400 frozen:0 ]"));
+      assertTrue(readBlockNumInfo.contains("Consumed[ hot:0 warm:0 cold:0 frozen:0 ]"));
     } else {
-      String readBlokNumInfo = composedClientReadHandler.getReadBlokNumInfo();
-      assert (readBlokNumInfo.contains("Client read 8 blocks from [" + ssi + "]")
-          && readBlokNumInfo.contains("Consumed[ hot:3 warm:3 cold:2 frozen:0 ]")
-          && readBlokNumInfo.contains("Skipped[ hot:0 warm:0 cold:0 frozen:0 ]"));
+      String readBlockNumInfo = composedClientReadHandler.getReadBlockNumInfo();
+      assertTrue(readBlockNumInfo.contains("Client read 8 blocks from [" + ssi + "]"));
+      assertTrue(readBlockNumInfo.contains("Consumed[ hot:3 warm:3 cold:2 frozen:0 ]"));
+      assertTrue(readBlockNumInfo.contains("Skipped[ hot:0 warm:0 cold:0 frozen:0 ]"));
       String readLengthInfo = composedClientReadHandler.getReadLengthInfo();
-      assert (readLengthInfo.contains("Client read 625 bytes from [" + ssi + "]")
-          && readLengthInfo.contains("Consumed[ hot:75 warm:150 cold:400 frozen:0 ]")
-          && readBlokNumInfo.contains("Skipped[ hot:0 warm:0 cold:0 frozen:0 ]"));
+      assertTrue(readLengthInfo.contains("Client read 625 bytes from [" + ssi + "]"));
+      assertTrue(readLengthInfo.contains("Consumed[ hot:75 warm:150 cold:400 frozen:0 ]"));
+      assertTrue(readBlockNumInfo.contains("Skipped[ hot:0 warm:0 cold:0 frozen:0 ]"));
       String readUncompressLengthInfo = composedClientReadHandler.getReadUncompressLengthInfo();
-      assert (readUncompressLengthInfo.contains("Client read 625 uncompressed bytes from [" + ssi + "]")
-          && readUncompressLengthInfo.contains("Consumed[ hot:75 warm:150 cold:400 frozen:0 ]")
-          && readBlokNumInfo.contains("Skipped[ hot:0 warm:0 cold:0 frozen:0 ]"));
+      assertTrue(readUncompressLengthInfo.contains("Client read 625 uncompressed bytes from [" + ssi + "]"));
+      assertTrue(readUncompressLengthInfo.contains("Consumed[ hot:75 warm:150 cold:400 frozen:0 ]"));
+      assertTrue(readBlockNumInfo.contains("Skipped[ hot:0 warm:0 cold:0 frozen:0 ]"));
     }
     
   }
diff --git a/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java b/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
index 91253482..a588809e 100644
--- a/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
+++ b/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
@@ -19,7 +19,7 @@ package org.apache.uniffle.storage.factory;
 
 import java.util.ArrayList;
 import java.util.List;
-import java.util.concurrent.Callable;
+import java.util.function.Supplier;
 
 import com.google.common.collect.Lists;
 import org.apache.commons.collections.CollectionUtils;
@@ -93,7 +93,7 @@ public class ShuffleHandlerFactory {
       return getLocalfileClientReaderHandler(request, serverInfo);
     }
 
-    List<Callable<ClientReadHandler>> handlers = new ArrayList<>();
+    List<Supplier<ClientReadHandler>> handlers = new ArrayList<>();
     if (StorageType.withMemory(type)) {
       handlers.add(
           () -> getMemoryClientReadHandler(request, serverInfo)
diff --git a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/ComposedClientReadHandler.java b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/ComposedClientReadHandler.java
index f3baeebe..1f00edbb 100644
--- a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/ComposedClientReadHandler.java
+++ b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/ComposedClientReadHandler.java
@@ -17,10 +17,15 @@
 
 package org.apache.uniffle.storage.handler.impl;
 
+import java.util.EnumMap;
 import java.util.List;
-import java.util.concurrent.Callable;
+import java.util.Map;
+import java.util.Objects;
+import java.util.function.Function;
+import java.util.function.Supplier;
 
 import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -40,101 +45,67 @@ public class ComposedClientReadHandler extends AbstractClientReadHandler {
 
   private static final Logger LOG = LoggerFactory.getLogger(ComposedClientReadHandler.class);
 
+  private enum Tier {
+    HOT, WARM, COLD, FROZEN;
+
+    static final Tier[] VALUES = Tier.values();
+
+    Tier next() {
+      return VALUES[this.ordinal() + 1];
+    }
+  }
+
   private final ShuffleServerInfo serverInfo;
-  private Callable<ClientReadHandler> hotHandlerCreator;
-  private Callable<ClientReadHandler> warmHandlerCreator;
-  private Callable<ClientReadHandler> coldHandlerCreator;
-  private Callable<ClientReadHandler> frozenHandlerCreator;
-  private ClientReadHandler hotDataReadHandler;
-  private ClientReadHandler warmDataReadHandler;
-  private ClientReadHandler coldDataReadHandler;
-  private ClientReadHandler frozenDataReadHandler;
-  private static final int HOT = 1;
-  private static final int WARM = 2;
-  private static final int COLD = 3;
-  private static final int FROZEN = 4;
-  private int currentHandler = HOT;
-  private final int topLevelOfHandler;
+  private final Map<Tier, Supplier<ClientReadHandler>> supplierMap = new EnumMap<>(Tier.class);
+  private final Map<Tier, ClientReadHandler> handlerMap = new EnumMap<>(Tier.class);
+  private final Map<Tier, ClientReadHandlerMetric> metricsMap = new EnumMap<>(Tier.class);
+  private Tier currentTier = Tier.VALUES[0]; // == Tier.HOT
+  private final int numTiers;
 
-  private ClientReadHandlerMetric hostHandlerMetric = new ClientReadHandlerMetric();
-  private ClientReadHandlerMetric warmHandlerMetric = new ClientReadHandlerMetric();
-  private ClientReadHandlerMetric coldHandlerMetric = new ClientReadHandlerMetric();
-  private ClientReadHandlerMetric frozenHandlerMetric = new ClientReadHandlerMetric();
+  {
+    for (Tier tier : Tier.VALUES) {
+      metricsMap.put(tier, new ClientReadHandlerMetric());
+    }
+  }
 
   public ComposedClientReadHandler(ShuffleServerInfo serverInfo, ClientReadHandler... handlers) {
+    Preconditions.checkArgument(handlers.length <= Tier.VALUES.length,
+        "Too many handlers, got %d, max %d", handlers.length, Tier.VALUES.length);
     this.serverInfo = serverInfo;
-    topLevelOfHandler = handlers.length;
-    if (topLevelOfHandler > 0) {
-      this.hotDataReadHandler = handlers[0];
-    }
-    if (topLevelOfHandler > 1) {
-      this.warmDataReadHandler = handlers[1];
-    }
-    if (topLevelOfHandler > 2) {
-      this.coldDataReadHandler = handlers[2];
-    }
-    if (topLevelOfHandler > 3) {
-      this.frozenDataReadHandler = handlers[3];
+    numTiers = handlers.length;
+    for (int i = 0; i < numTiers; i++) {
+      handlerMap.put(Tier.VALUES[i], handlers[i]);
     }
   }
 
-  public ComposedClientReadHandler(ShuffleServerInfo serverInfo, List<Callable<ClientReadHandler>> callables) {
+  public ComposedClientReadHandler(ShuffleServerInfo serverInfo, List<Supplier<ClientReadHandler>> suppliers) {
+    Preconditions.checkArgument(suppliers.size() <= Tier.VALUES.length,
+        "Too many suppliers, got %d, max %d", suppliers.size(), Tier.VALUES.length);
     this.serverInfo = serverInfo;
-    topLevelOfHandler = callables.size();
-    if (topLevelOfHandler > 0) {
-      this.hotHandlerCreator = callables.get(0);
-    }
-    if (topLevelOfHandler > 1) {
-      this.warmHandlerCreator = callables.get(1);
-    }
-    if (topLevelOfHandler > 2) {
-      this.coldHandlerCreator = callables.get(2);
-    }
-    if (topLevelOfHandler > 3) {
-      this.frozenHandlerCreator = callables.get(3);
+    numTiers = suppliers.size();
+    for (int i = 0; i < numTiers; i++) {
+      supplierMap.put(Tier.VALUES[i], suppliers.get(i));
     }
   }
 
   @Override
   public ShuffleDataResult readShuffleData() {
-    ShuffleDataResult shuffleDataResult = null;
+    ClientReadHandler handler = handlerMap.computeIfAbsent(currentTier,
+        key -> supplierMap.getOrDefault(key, () -> null).get());
+    if (handler == null) {
+      throw new RssException("Unexpected null when getting " + currentTier.name() + " handler");
+    }
+    ShuffleDataResult shuffleDataResult;
     try {
-      switch (currentHandler) {
-        case HOT:
-          if (hotDataReadHandler == null) {
-            hotDataReadHandler = hotHandlerCreator.call();
-          }
-          shuffleDataResult = hotDataReadHandler.readShuffleData();
-          break;
-        case WARM:
-          if (warmDataReadHandler == null) {
-            warmDataReadHandler = warmHandlerCreator.call();
-          }
-          shuffleDataResult = warmDataReadHandler.readShuffleData();
-          break;
-        case COLD:
-          if (coldDataReadHandler == null) {
-            coldDataReadHandler = coldHandlerCreator.call();
-          }
-          shuffleDataResult = coldDataReadHandler.readShuffleData();
-          break;
-        case FROZEN:
-          if (frozenDataReadHandler == null) {
-            frozenDataReadHandler = frozenHandlerCreator.call();
-          }
-          shuffleDataResult = frozenDataReadHandler.readShuffleData();
-          break;
-        default:
-          return null;
-      }
+      shuffleDataResult = handler.readShuffleData();
     } catch (Exception e) {
-      throw new RssException("Failed to read shuffle data from " + getCurrentHandlerName() + " handler", e);
+      throw new RssException("Failed to read shuffle data from " + currentTier.name() + " handler", e);
     }
     // when is no data for current handler, and the upmostLevel is not reached,
     // then try next one if there has
     if (shuffleDataResult == null || shuffleDataResult.isEmpty()) {
-      if (currentHandler < topLevelOfHandler) {
-        currentHandler++;
+      if (currentTier.ordinal() + 1 < numTiers) {
+        currentTier = currentTier.next();
       } else {
         return null;
       }
@@ -144,44 +115,9 @@ public class ComposedClientReadHandler extends AbstractClientReadHandler {
     return shuffleDataResult;
   }
 
-  private String getCurrentHandlerName() {
-    String name = "UNKNOWN";
-    switch (currentHandler) {
-      case HOT:
-        name = "HOT";
-        break;
-      case WARM:
-        name = "WARM";
-        break;
-      case COLD:
-        name = "COLD";
-        break;
-      case FROZEN:
-        name = "FROZEN";
-        break;
-      default:
-        break;
-    }
-    return name;
-  }
-
   @Override
   public void close() {
-    if (hotDataReadHandler != null) {
-      hotDataReadHandler.close();
-    }
-
-    if (warmDataReadHandler != null) {
-      warmDataReadHandler.close();
-    }
-
-    if (coldDataReadHandler != null) {
-      coldDataReadHandler.close();
-    }
-
-    if (frozenDataReadHandler != null) {
-      frozenDataReadHandler.close();
-    }
+    handlerMap.values().stream().filter(Objects::nonNull).forEach(ClientReadHandler::close);
   }
 
   @Override
@@ -190,71 +126,49 @@ public class ComposedClientReadHandler extends AbstractClientReadHandler {
       return;
     }
     super.updateConsumedBlockInfo(bs, isSkippedMetrics);
-    switch (currentHandler) {
-      case HOT:
-        updateBlockMetric(hostHandlerMetric, bs, isSkippedMetrics);
-        break;
-      case WARM:
-        updateBlockMetric(warmHandlerMetric, bs, isSkippedMetrics);
-        break;
-      case COLD:
-        updateBlockMetric(coldHandlerMetric, bs, isSkippedMetrics);
-        break;
-      case FROZEN:
-        updateBlockMetric(frozenHandlerMetric, bs, isSkippedMetrics);
-        break;
-      default:
-        break;
-    }
+    updateBlockMetric(metricsMap.get(currentTier), bs, isSkippedMetrics);
   }
 
   @Override
   public void logConsumedBlockInfo() {
-    LOG.info(getReadBlokNumInfo());
+    LOG.info(getReadBlockNumInfo());
     LOG.info(getReadLengthInfo());
     LOG.info(getReadUncompressLengthInfo());
   }
 
   @VisibleForTesting
-  public String getReadBlokNumInfo() {
-    return "Client read " + readHandlerMetric.getReadBlockNum()
-        + " blocks from [" + serverInfo + "], Consumed["
-        + " hot:" + hostHandlerMetric.getReadBlockNum()
-        + " warm:" + warmHandlerMetric.getReadBlockNum()
-        + " cold:" + coldHandlerMetric.getReadBlockNum()
-        + " frozen:" + frozenHandlerMetric.getReadBlockNum()
-        + " ], Skipped[" + " hot:" + hostHandlerMetric.getSkippedReadBlockNum()
-        + " warm:" + warmHandlerMetric.getSkippedReadBlockNum()
-        + " cold:" + coldHandlerMetric.getSkippedReadBlockNum()
-        + " frozen:" + frozenHandlerMetric.getSkippedReadBlockNum() + " ]";
+  public String getReadBlockNumInfo() {
+    return getMetricsInfo("blocks", ClientReadHandlerMetric::getReadBlockNum,
+        ClientReadHandlerMetric::getSkippedReadBlockNum);
   }
 
   @VisibleForTesting
   public String getReadLengthInfo() {
-    return "Client read " + readHandlerMetric.getReadLength()
-        + " bytes from [" + serverInfo + "], Consumed["
-        + " hot:" + hostHandlerMetric.getReadLength()
-        + " warm:" + warmHandlerMetric.getReadLength()
-        + " cold:" + coldHandlerMetric.getReadLength()
-        + " frozen:" + frozenHandlerMetric.getReadLength() + " ], Skipped["
-        + " hot:" + hostHandlerMetric.getSkippedReadLength()
-        + " warm:" + warmHandlerMetric.getSkippedReadLength()
-        + " cold:" + coldHandlerMetric.getSkippedReadLength()
-        + " frozen:" + frozenHandlerMetric.getSkippedReadLength() + " ]";
+    return getMetricsInfo("bytes", ClientReadHandlerMetric::getReadLength,
+        ClientReadHandlerMetric::getSkippedReadLength);
   }
 
   @VisibleForTesting
   public String getReadUncompressLengthInfo() {
-    return "Client read " + readHandlerMetric.getReadUncompressLength()
-        + " uncompressed bytes from [" + serverInfo + "], Consumed["
-        + " hot:" + hostHandlerMetric.getReadUncompressLength()
-        + " warm:" + warmHandlerMetric.getReadUncompressLength()
-        + " cold:" + coldHandlerMetric.getReadUncompressLength()
-        + " frozen:" + frozenHandlerMetric.getReadUncompressLength() + " ], Skipped["
-        + " hot:" + hostHandlerMetric.getSkippedReadUncompressLength()
-        + " warm:" + warmHandlerMetric.getSkippedReadUncompressLength()
-        + " cold:" + coldHandlerMetric.getSkippedReadUncompressLength()
-        + " frozen:" + frozenHandlerMetric.getSkippedReadUncompressLength() + " ]";
+    return getMetricsInfo("uncompressed bytes", ClientReadHandlerMetric::getReadUncompressLength,
+        ClientReadHandlerMetric::getSkippedReadUncompressLength);
+  }
+
+  private String getMetricsInfo(String name, Function<ClientReadHandlerMetric, Long> consumed,
+      Function<ClientReadHandlerMetric, Long> skipped) {
+    StringBuilder sb = new StringBuilder("Client read ").append(consumed.apply(readHandlerMetric))
+        .append(" ").append(name).append(" from [").append(serverInfo).append("], Consumed[");
+    for (Tier tier : Tier.VALUES) {
+      sb.append(" ").append(tier.name().toLowerCase()).append(":")
+          .append(consumed.apply(metricsMap.get(tier)));
+    }
+    sb.append(" ], Skipped[");
+    for (Tier tier : Tier.VALUES) {
+      sb.append(" ").append(tier.name().toLowerCase()).append(":")
+          .append(skipped.apply(metricsMap.get(tier)));
+    }
+    sb.append(" ]");
+    return sb.toString();
   }
 
 }