You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@uniffle.apache.org by xi...@apache.org on 2023/02/10 03:47:37 UTC
[incubator-uniffle] branch master updated: [#575] refactor: replace switch-case with EnumMap in ComposedClientReadHandler (#570)
This is an automated email from the ASF dual-hosted git repository.
xianjingfeng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new 9da9bb66 [#575] refactor: replace switch-case with EnumMap in ComposedClientReadHandler (#570)
9da9bb66 is described below
commit 9da9bb66f6ece6cf0b89dc931129edda63e149c6
Author: Kaijie Chen <ck...@apache.org>
AuthorDate: Fri Feb 10 11:47:32 2023 +0800
[#575] refactor: replace switch-case with EnumMap in ComposedClientReadHandler (#570)
### What changes were proposed in this pull request?
Add enum Tier, use EnumMap to replace switch-case in ComposedClientReadHandler.
### Why are the changes needed?
Make the code more concise.
### Does this PR introduce any user-facing change?
No.
### How was this patch tested?
Existing CI.
---
.../test/ShuffleServerWithMemLocalHdfsTest.java | 41 ++--
.../storage/factory/ShuffleHandlerFactory.java | 4 +-
.../handler/impl/ComposedClientReadHandler.java | 234 +++++++--------------
3 files changed, 97 insertions(+), 182 deletions(-)
diff --git a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithMemLocalHdfsTest.java b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithMemLocalHdfsTest.java
index c27fe6f9..5fc74979 100644
--- a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithMemLocalHdfsTest.java
+++ b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithMemLocalHdfsTest.java
@@ -52,6 +52,7 @@ import org.apache.uniffle.storage.util.StorageType;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
public class ShuffleServerWithMemLocalHdfsTest extends ShuffleReadWriteBase {
@@ -211,31 +212,31 @@ public class ShuffleServerWithMemLocalHdfsTest extends ShuffleReadWriteBase {
assertNull(sdr);
if (checkSkippedMetrics) {
- String readBlokNumInfo = composedClientReadHandler.getReadBlokNumInfo();
- assert (readBlokNumInfo.contains("Client read 0 blocks from [" + ssi + "]")
- && readBlokNumInfo.contains("Skipped[ hot:3 warm:3 cold:2 frozen:0 ]")
- && readBlokNumInfo.contains("Consumed[ hot:0 warm:0 cold:0 frozen:0 ]"));
+ String readBlockNumInfo = composedClientReadHandler.getReadBlockNumInfo();
+ assertTrue(readBlockNumInfo.contains("Client read 0 blocks from [" + ssi + "]"));
+ assertTrue(readBlockNumInfo.contains("Skipped[ hot:3 warm:3 cold:2 frozen:0 ]"));
+ assertTrue(readBlockNumInfo.contains("Consumed[ hot:0 warm:0 cold:0 frozen:0 ]"));
String readLengthInfo = composedClientReadHandler.getReadLengthInfo();
- assert (readLengthInfo.contains("Client read 0 bytes from [" + ssi + "]")
- && readLengthInfo.contains("Skipped[ hot:75 warm:150 cold:400 frozen:0 ]")
- && readBlokNumInfo.contains("Consumed[ hot:0 warm:0 cold:0 frozen:0 ]"));
+ assertTrue(readLengthInfo.contains("Client read 0 bytes from [" + ssi + "]"));
+ assertTrue(readLengthInfo.contains("Skipped[ hot:75 warm:150 cold:400 frozen:0 ]"));
+ assertTrue(readBlockNumInfo.contains("Consumed[ hot:0 warm:0 cold:0 frozen:0 ]"));
String readUncompressLengthInfo = composedClientReadHandler.getReadUncompressLengthInfo();
- assert (readUncompressLengthInfo.contains("Client read 0 uncompressed bytes from [" + ssi + "]")
- && readUncompressLengthInfo.contains("Skipped[ hot:75 warm:150 cold:400 frozen:0 ]")
- && readBlokNumInfo.contains("Consumed[ hot:0 warm:0 cold:0 frozen:0 ]"));
+ assertTrue(readUncompressLengthInfo.contains("Client read 0 uncompressed bytes from [" + ssi + "]"));
+ assertTrue(readUncompressLengthInfo.contains("Skipped[ hot:75 warm:150 cold:400 frozen:0 ]"));
+ assertTrue(readBlockNumInfo.contains("Consumed[ hot:0 warm:0 cold:0 frozen:0 ]"));
} else {
- String readBlokNumInfo = composedClientReadHandler.getReadBlokNumInfo();
- assert (readBlokNumInfo.contains("Client read 8 blocks from [" + ssi + "]")
- && readBlokNumInfo.contains("Consumed[ hot:3 warm:3 cold:2 frozen:0 ]")
- && readBlokNumInfo.contains("Skipped[ hot:0 warm:0 cold:0 frozen:0 ]"));
+ String readBlockNumInfo = composedClientReadHandler.getReadBlockNumInfo();
+ assertTrue(readBlockNumInfo.contains("Client read 8 blocks from [" + ssi + "]"));
+ assertTrue(readBlockNumInfo.contains("Consumed[ hot:3 warm:3 cold:2 frozen:0 ]"));
+ assertTrue(readBlockNumInfo.contains("Skipped[ hot:0 warm:0 cold:0 frozen:0 ]"));
String readLengthInfo = composedClientReadHandler.getReadLengthInfo();
- assert (readLengthInfo.contains("Client read 625 bytes from [" + ssi + "]")
- && readLengthInfo.contains("Consumed[ hot:75 warm:150 cold:400 frozen:0 ]")
- && readBlokNumInfo.contains("Skipped[ hot:0 warm:0 cold:0 frozen:0 ]"));
+ assertTrue(readLengthInfo.contains("Client read 625 bytes from [" + ssi + "]"));
+ assertTrue(readLengthInfo.contains("Consumed[ hot:75 warm:150 cold:400 frozen:0 ]"));
+ assertTrue(readBlockNumInfo.contains("Skipped[ hot:0 warm:0 cold:0 frozen:0 ]"));
String readUncompressLengthInfo = composedClientReadHandler.getReadUncompressLengthInfo();
- assert (readUncompressLengthInfo.contains("Client read 625 uncompressed bytes from [" + ssi + "]")
- && readUncompressLengthInfo.contains("Consumed[ hot:75 warm:150 cold:400 frozen:0 ]")
- && readBlokNumInfo.contains("Skipped[ hot:0 warm:0 cold:0 frozen:0 ]"));
+ assertTrue(readUncompressLengthInfo.contains("Client read 625 uncompressed bytes from [" + ssi + "]"));
+ assertTrue(readUncompressLengthInfo.contains("Consumed[ hot:75 warm:150 cold:400 frozen:0 ]"));
+ assertTrue(readBlockNumInfo.contains("Skipped[ hot:0 warm:0 cold:0 frozen:0 ]"));
}
}
diff --git a/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java b/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
index 91253482..a588809e 100644
--- a/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
+++ b/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
@@ -19,7 +19,7 @@ package org.apache.uniffle.storage.factory;
import java.util.ArrayList;
import java.util.List;
-import java.util.concurrent.Callable;
+import java.util.function.Supplier;
import com.google.common.collect.Lists;
import org.apache.commons.collections.CollectionUtils;
@@ -93,7 +93,7 @@ public class ShuffleHandlerFactory {
return getLocalfileClientReaderHandler(request, serverInfo);
}
- List<Callable<ClientReadHandler>> handlers = new ArrayList<>();
+ List<Supplier<ClientReadHandler>> handlers = new ArrayList<>();
if (StorageType.withMemory(type)) {
handlers.add(
() -> getMemoryClientReadHandler(request, serverInfo)
diff --git a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/ComposedClientReadHandler.java b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/ComposedClientReadHandler.java
index f3baeebe..1f00edbb 100644
--- a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/ComposedClientReadHandler.java
+++ b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/ComposedClientReadHandler.java
@@ -17,10 +17,15 @@
package org.apache.uniffle.storage.handler.impl;
+import java.util.EnumMap;
import java.util.List;
-import java.util.concurrent.Callable;
+import java.util.Map;
+import java.util.Objects;
+import java.util.function.Function;
+import java.util.function.Supplier;
import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -40,101 +45,67 @@ public class ComposedClientReadHandler extends AbstractClientReadHandler {
private static final Logger LOG = LoggerFactory.getLogger(ComposedClientReadHandler.class);
+ private enum Tier {
+ HOT, WARM, COLD, FROZEN;
+
+ static final Tier[] VALUES = Tier.values();
+
+ Tier next() {
+ return VALUES[this.ordinal() + 1];
+ }
+ }
+
private final ShuffleServerInfo serverInfo;
- private Callable<ClientReadHandler> hotHandlerCreator;
- private Callable<ClientReadHandler> warmHandlerCreator;
- private Callable<ClientReadHandler> coldHandlerCreator;
- private Callable<ClientReadHandler> frozenHandlerCreator;
- private ClientReadHandler hotDataReadHandler;
- private ClientReadHandler warmDataReadHandler;
- private ClientReadHandler coldDataReadHandler;
- private ClientReadHandler frozenDataReadHandler;
- private static final int HOT = 1;
- private static final int WARM = 2;
- private static final int COLD = 3;
- private static final int FROZEN = 4;
- private int currentHandler = HOT;
- private final int topLevelOfHandler;
+ private final Map<Tier, Supplier<ClientReadHandler>> supplierMap = new EnumMap<>(Tier.class);
+ private final Map<Tier, ClientReadHandler> handlerMap = new EnumMap<>(Tier.class);
+ private final Map<Tier, ClientReadHandlerMetric> metricsMap = new EnumMap<>(Tier.class);
+ private Tier currentTier = Tier.VALUES[0]; // == Tier.HOT
+ private final int numTiers;
- private ClientReadHandlerMetric hostHandlerMetric = new ClientReadHandlerMetric();
- private ClientReadHandlerMetric warmHandlerMetric = new ClientReadHandlerMetric();
- private ClientReadHandlerMetric coldHandlerMetric = new ClientReadHandlerMetric();
- private ClientReadHandlerMetric frozenHandlerMetric = new ClientReadHandlerMetric();
+ {
+ for (Tier tier : Tier.VALUES) {
+ metricsMap.put(tier, new ClientReadHandlerMetric());
+ }
+ }
public ComposedClientReadHandler(ShuffleServerInfo serverInfo, ClientReadHandler... handlers) {
+ Preconditions.checkArgument(handlers.length <= Tier.VALUES.length,
+ "Too many handlers, got %d, max %d", handlers.length, Tier.VALUES.length);
this.serverInfo = serverInfo;
- topLevelOfHandler = handlers.length;
- if (topLevelOfHandler > 0) {
- this.hotDataReadHandler = handlers[0];
- }
- if (topLevelOfHandler > 1) {
- this.warmDataReadHandler = handlers[1];
- }
- if (topLevelOfHandler > 2) {
- this.coldDataReadHandler = handlers[2];
- }
- if (topLevelOfHandler > 3) {
- this.frozenDataReadHandler = handlers[3];
+ numTiers = handlers.length;
+ for (int i = 0; i < numTiers; i++) {
+ handlerMap.put(Tier.VALUES[i], handlers[i]);
}
}
- public ComposedClientReadHandler(ShuffleServerInfo serverInfo, List<Callable<ClientReadHandler>> callables) {
+ public ComposedClientReadHandler(ShuffleServerInfo serverInfo, List<Supplier<ClientReadHandler>> suppliers) {
+ Preconditions.checkArgument(suppliers.size() <= Tier.VALUES.length,
+ "Too many suppliers, got %d, max %d", suppliers.size(), Tier.VALUES.length);
this.serverInfo = serverInfo;
- topLevelOfHandler = callables.size();
- if (topLevelOfHandler > 0) {
- this.hotHandlerCreator = callables.get(0);
- }
- if (topLevelOfHandler > 1) {
- this.warmHandlerCreator = callables.get(1);
- }
- if (topLevelOfHandler > 2) {
- this.coldHandlerCreator = callables.get(2);
- }
- if (topLevelOfHandler > 3) {
- this.frozenHandlerCreator = callables.get(3);
+ numTiers = suppliers.size();
+ for (int i = 0; i < numTiers; i++) {
+ supplierMap.put(Tier.VALUES[i], suppliers.get(i));
}
}
@Override
public ShuffleDataResult readShuffleData() {
- ShuffleDataResult shuffleDataResult = null;
+ ClientReadHandler handler = handlerMap.computeIfAbsent(currentTier,
+ key -> supplierMap.getOrDefault(key, () -> null).get());
+ if (handler == null) {
+ throw new RssException("Unexpected null when getting " + currentTier.name() + " handler");
+ }
+ ShuffleDataResult shuffleDataResult;
try {
- switch (currentHandler) {
- case HOT:
- if (hotDataReadHandler == null) {
- hotDataReadHandler = hotHandlerCreator.call();
- }
- shuffleDataResult = hotDataReadHandler.readShuffleData();
- break;
- case WARM:
- if (warmDataReadHandler == null) {
- warmDataReadHandler = warmHandlerCreator.call();
- }
- shuffleDataResult = warmDataReadHandler.readShuffleData();
- break;
- case COLD:
- if (coldDataReadHandler == null) {
- coldDataReadHandler = coldHandlerCreator.call();
- }
- shuffleDataResult = coldDataReadHandler.readShuffleData();
- break;
- case FROZEN:
- if (frozenDataReadHandler == null) {
- frozenDataReadHandler = frozenHandlerCreator.call();
- }
- shuffleDataResult = frozenDataReadHandler.readShuffleData();
- break;
- default:
- return null;
- }
+ shuffleDataResult = handler.readShuffleData();
} catch (Exception e) {
- throw new RssException("Failed to read shuffle data from " + getCurrentHandlerName() + " handler", e);
+ throw new RssException("Failed to read shuffle data from " + currentTier.name() + " handler", e);
}
// when there is no data for the current handler and the topmost level has not been reached,
// try the next one, if there is one
if (shuffleDataResult == null || shuffleDataResult.isEmpty()) {
- if (currentHandler < topLevelOfHandler) {
- currentHandler++;
+ if (currentTier.ordinal() + 1 < numTiers) {
+ currentTier = currentTier.next();
} else {
return null;
}
@@ -144,44 +115,9 @@ public class ComposedClientReadHandler extends AbstractClientReadHandler {
return shuffleDataResult;
}
- private String getCurrentHandlerName() {
- String name = "UNKNOWN";
- switch (currentHandler) {
- case HOT:
- name = "HOT";
- break;
- case WARM:
- name = "WARM";
- break;
- case COLD:
- name = "COLD";
- break;
- case FROZEN:
- name = "FROZEN";
- break;
- default:
- break;
- }
- return name;
- }
-
@Override
public void close() {
- if (hotDataReadHandler != null) {
- hotDataReadHandler.close();
- }
-
- if (warmDataReadHandler != null) {
- warmDataReadHandler.close();
- }
-
- if (coldDataReadHandler != null) {
- coldDataReadHandler.close();
- }
-
- if (frozenDataReadHandler != null) {
- frozenDataReadHandler.close();
- }
+ handlerMap.values().stream().filter(Objects::nonNull).forEach(ClientReadHandler::close);
}
@Override
@@ -190,71 +126,49 @@ public class ComposedClientReadHandler extends AbstractClientReadHandler {
return;
}
super.updateConsumedBlockInfo(bs, isSkippedMetrics);
- switch (currentHandler) {
- case HOT:
- updateBlockMetric(hostHandlerMetric, bs, isSkippedMetrics);
- break;
- case WARM:
- updateBlockMetric(warmHandlerMetric, bs, isSkippedMetrics);
- break;
- case COLD:
- updateBlockMetric(coldHandlerMetric, bs, isSkippedMetrics);
- break;
- case FROZEN:
- updateBlockMetric(frozenHandlerMetric, bs, isSkippedMetrics);
- break;
- default:
- break;
- }
+ updateBlockMetric(metricsMap.get(currentTier), bs, isSkippedMetrics);
}
@Override
public void logConsumedBlockInfo() {
- LOG.info(getReadBlokNumInfo());
+ LOG.info(getReadBlockNumInfo());
LOG.info(getReadLengthInfo());
LOG.info(getReadUncompressLengthInfo());
}
@VisibleForTesting
- public String getReadBlokNumInfo() {
- return "Client read " + readHandlerMetric.getReadBlockNum()
- + " blocks from [" + serverInfo + "], Consumed["
- + " hot:" + hostHandlerMetric.getReadBlockNum()
- + " warm:" + warmHandlerMetric.getReadBlockNum()
- + " cold:" + coldHandlerMetric.getReadBlockNum()
- + " frozen:" + frozenHandlerMetric.getReadBlockNum()
- + " ], Skipped[" + " hot:" + hostHandlerMetric.getSkippedReadBlockNum()
- + " warm:" + warmHandlerMetric.getSkippedReadBlockNum()
- + " cold:" + coldHandlerMetric.getSkippedReadBlockNum()
- + " frozen:" + frozenHandlerMetric.getSkippedReadBlockNum() + " ]";
+ public String getReadBlockNumInfo() {
+ return getMetricsInfo("blocks", ClientReadHandlerMetric::getReadBlockNum,
+ ClientReadHandlerMetric::getSkippedReadBlockNum);
}
@VisibleForTesting
public String getReadLengthInfo() {
- return "Client read " + readHandlerMetric.getReadLength()
- + " bytes from [" + serverInfo + "], Consumed["
- + " hot:" + hostHandlerMetric.getReadLength()
- + " warm:" + warmHandlerMetric.getReadLength()
- + " cold:" + coldHandlerMetric.getReadLength()
- + " frozen:" + frozenHandlerMetric.getReadLength() + " ], Skipped["
- + " hot:" + hostHandlerMetric.getSkippedReadLength()
- + " warm:" + warmHandlerMetric.getSkippedReadLength()
- + " cold:" + coldHandlerMetric.getSkippedReadLength()
- + " frozen:" + frozenHandlerMetric.getSkippedReadLength() + " ]";
+ return getMetricsInfo("bytes", ClientReadHandlerMetric::getReadLength,
+ ClientReadHandlerMetric::getSkippedReadLength);
}
@VisibleForTesting
public String getReadUncompressLengthInfo() {
- return "Client read " + readHandlerMetric.getReadUncompressLength()
- + " uncompressed bytes from [" + serverInfo + "], Consumed["
- + " hot:" + hostHandlerMetric.getReadUncompressLength()
- + " warm:" + warmHandlerMetric.getReadUncompressLength()
- + " cold:" + coldHandlerMetric.getReadUncompressLength()
- + " frozen:" + frozenHandlerMetric.getReadUncompressLength() + " ], Skipped["
- + " hot:" + hostHandlerMetric.getSkippedReadUncompressLength()
- + " warm:" + warmHandlerMetric.getSkippedReadUncompressLength()
- + " cold:" + coldHandlerMetric.getSkippedReadUncompressLength()
- + " frozen:" + frozenHandlerMetric.getSkippedReadUncompressLength() + " ]";
+ return getMetricsInfo("uncompressed bytes", ClientReadHandlerMetric::getReadUncompressLength,
+ ClientReadHandlerMetric::getSkippedReadUncompressLength);
+ }
+
+ private String getMetricsInfo(String name, Function<ClientReadHandlerMetric, Long> consumed,
+ Function<ClientReadHandlerMetric, Long> skipped) {
+ StringBuilder sb = new StringBuilder("Client read ").append(consumed.apply(readHandlerMetric))
+ .append(" ").append(name).append(" from [").append(serverInfo).append("], Consumed[");
+ for (Tier tier : Tier.VALUES) {
+ sb.append(" ").append(tier.name().toLowerCase()).append(":")
+ .append(consumed.apply(metricsMap.get(tier)));
+ }
+ sb.append(" ], Skipped[");
+ for (Tier tier : Tier.VALUES) {
+ sb.append(" ").append(tier.name().toLowerCase()).append(":")
+ .append(skipped.apply(metricsMap.get(tier)));
+ }
+ sb.append(" ]");
+ return sb.toString();
}
}