You are viewing a plain-text version of this content. The canonical link to it was a hyperlink on the original archive page and is not preserved in this copy.
Posted to commits@ratis.apache.org by sz...@apache.org on 2020/11/19 06:21:03 UTC

[incubator-ratis] branch master updated: RATIS-1165. Add ClientInvocationId. (#288)

This is an automated email from the ASF dual-hosted git repository.

szetszwo pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-ratis.git


The following commit(s) were added to refs/heads/master by this push:
     new 0a934cf  RATIS-1165. Add ClientInvocationId. (#288)
0a934cf is described below

commit 0a934cf45f51569fb5f60a36563fdab84fd608d1
Author: Tsz-Wo Nicholas Sze <sz...@apache.org>
AuthorDate: Thu Nov 19 14:12:12 2020 +0800

    RATIS-1165. Add ClientInvocationId. (#288)
---
 .../ratis/client/impl/RaftClientTestUtil.java      |  5 +-
 .../apache/ratis/protocol/ClientInvocationId.java  | 80 ++++++++++++++++++++++
 .../org/apache/ratis/protocol/RaftClientReply.java |  5 ++
 .../ratis/server/impl/MessageStreamRequests.java   | 48 +++----------
 .../apache/ratis/server/impl/RaftServerImpl.java   | 26 +++----
 .../org/apache/ratis/server/impl/RetryCache.java   | 54 +++------------
 .../test/java/org/apache/ratis/RaftBasicTests.java |  5 +-
 .../ratis/server/impl/RaftServerTestUtil.java      |  6 +-
 .../ratis/server/impl/RetryCacheTestUtil.java      | 19 ++---
 .../ratis/server/impl/TestRetryCacheMetrics.java   |  9 +--
 .../ratis/datastream/DataStreamBaseTest.java       | 34 ++-------
 11 files changed, 138 insertions(+), 153 deletions(-)

diff --git a/ratis-client/src/main/java/org/apache/ratis/client/impl/RaftClientTestUtil.java b/ratis-client/src/main/java/org/apache/ratis/client/impl/RaftClientTestUtil.java
index fbb62f7..6dc65e5 100644
--- a/ratis-client/src/main/java/org/apache/ratis/client/impl/RaftClientTestUtil.java
+++ b/ratis-client/src/main/java/org/apache/ratis/client/impl/RaftClientTestUtil.java
@@ -19,6 +19,7 @@ package org.apache.ratis.client.impl;
 
 import org.apache.ratis.client.RaftClient;
 import org.apache.ratis.proto.RaftProtos.SlidingWindowEntry;
+import org.apache.ratis.protocol.ClientInvocationId;
 import org.apache.ratis.protocol.Message;
 import org.apache.ratis.protocol.RaftClientRequest;
 import org.apache.ratis.protocol.RaftPeerId;
@@ -29,8 +30,8 @@ public interface RaftClientTestUtil {
     ((RaftClientImpl) client).getOrderedAsync().assertRequestSemaphore(expectedAvailablePermits, expectedQueueLength);
   }
 
-  static long getCallId(RaftClient client) {
-    return ((RaftClientImpl) client).getCallId();
+  static ClientInvocationId getClientInvocationId(RaftClient client) {
+    return ClientInvocationId.valueOf(client.getId(), ((RaftClientImpl)client).getCallId());
   }
 
   static RaftClientRequest newRaftClientRequest(RaftClient client, RaftPeerId server,
diff --git a/ratis-common/src/main/java/org/apache/ratis/protocol/ClientInvocationId.java b/ratis-common/src/main/java/org/apache/ratis/protocol/ClientInvocationId.java
new file mode 100644
index 0000000..679b85c
--- /dev/null
+++ b/ratis-common/src/main/java/org/apache/ratis/protocol/ClientInvocationId.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.ratis.protocol;
+
+import org.apache.ratis.proto.RaftProtos.StateMachineLogEntryProto;
+
+import java.util.Objects;
+
+/**
+ * The id of a client invocation.
+ * A client invocation may be an RPC or a stream.
+ *
+ * This is a value-based class.
+ */
+public final class ClientInvocationId {
+  public static ClientInvocationId valueOf(ClientId clientId, long invocationId) {
+    return new ClientInvocationId(clientId, invocationId);
+  }
+
+  public static ClientInvocationId valueOf(RaftClientMessage message) {
+    return valueOf(message.getClientId(), message.getCallId());
+  }
+
+  public static ClientInvocationId valueOf(StateMachineLogEntryProto proto) {
+    return valueOf(ClientId.valueOf(proto.getClientId()), proto.getCallId());
+  }
+
+  private final ClientId clientId;
+  /** It may be a call id or a stream id. */
+  private final long longId;
+
+  private ClientInvocationId(ClientId clientId, long longId) {
+    this.clientId = clientId;
+    this.longId = longId;
+  }
+
+  public ClientId getClientId() {
+    return clientId;
+  }
+
+  public long getLongId() {
+    return longId;
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    if (this == obj) {
+      return true;
+    } else if (obj == null || getClass() != obj.getClass()) {
+      return false;
+    }
+    final ClientInvocationId that = (ClientInvocationId) obj;
+    return this.longId == that.longId && Objects.equals(this.clientId, that.clientId);
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hash(clientId, longId);
+  }
+
+  @Override
+  public String toString() {
+    return longId + "@" + clientId;
+  }
+}
diff --git a/ratis-common/src/main/java/org/apache/ratis/protocol/RaftClientReply.java b/ratis-common/src/main/java/org/apache/ratis/protocol/RaftClientReply.java
index a45bd45..e61b534 100644
--- a/ratis-common/src/main/java/org/apache/ratis/protocol/RaftClientReply.java
+++ b/ratis-common/src/main/java/org/apache/ratis/protocol/RaftClientReply.java
@@ -112,6 +112,11 @@ public class RaftClientReply extends RaftClientMessage {
           .setGroupId(serverId.getGroupId());
     }
 
+    public Builder setClientInvocationId(ClientInvocationId invocationId) {
+      return setClientId(invocationId.getClientId())
+          .setCallId(invocationId.getLongId());
+    }
+
     public Builder setRequest(RaftClientRequest request) {
       return setClientId(request.getClientId())
           .setServerId(request.getServerId())
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/MessageStreamRequests.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/MessageStreamRequests.java
index adb7e45..ac81b34 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/impl/MessageStreamRequests.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/MessageStreamRequests.java
@@ -18,7 +18,7 @@
 package org.apache.ratis.server.impl;
 
 import org.apache.ratis.proto.RaftProtos.MessageStreamRequestTypeProto;
-import org.apache.ratis.protocol.ClientId;
+import org.apache.ratis.protocol.ClientInvocationId;
 import org.apache.ratis.protocol.Message;
 import org.apache.ratis.protocol.RaftClientRequest;
 import org.apache.ratis.protocol.exceptions.StreamException;
@@ -28,7 +28,6 @@ import org.apache.ratis.util.Preconditions;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-import java.util.Objects;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentMap;
@@ -36,43 +35,12 @@ import java.util.concurrent.ConcurrentMap;
 class MessageStreamRequests {
   public static final Logger LOG = LoggerFactory.getLogger(MessageStreamRequests.class);
 
-  private static class Key {
-    private final ClientId clientId;
-    private final long streamId;
-
-    Key(ClientId clientId, long streamId) {
-      this.clientId = clientId;
-      this.streamId = streamId;
-    }
-
-    @Override
-    public boolean equals(Object obj) {
-      if (this == obj) {
-        return true;
-      } else if (obj == null || getClass() != obj.getClass()) {
-        return false;
-      }
-      final Key that = (Key) obj;
-      return this.streamId == that.streamId && this.clientId.equals(that.clientId);
-    }
-
-    @Override
-    public int hashCode() {
-      return Objects.hash(clientId, streamId);
-    }
-
-    @Override
-    public String toString() {
-      return "Stream" + streamId + "@" + clientId;
-    }
-  }
-
   private static class PendingStream {
-    private final Key key;
+    private final ClientInvocationId key;
     private long nextId = -1;
     private ByteString bytes = ByteString.EMPTY;
 
-    PendingStream(Key key) {
+    PendingStream(ClientInvocationId key) {
       this.key = key;
     }
 
@@ -94,13 +62,13 @@ class MessageStreamRequests {
   }
 
   static class StreamMap {
-    private final ConcurrentMap<Key, PendingStream> map = new ConcurrentHashMap<>();
+    private final ConcurrentMap<ClientInvocationId, PendingStream> map = new ConcurrentHashMap<>();
 
-    PendingStream computeIfAbsent(Key key) {
+    PendingStream computeIfAbsent(ClientInvocationId key) {
       return map.computeIfAbsent(key, PendingStream::new);
     }
 
-    PendingStream remove(Key key) {
+    PendingStream remove(ClientInvocationId key) {
       return map.remove(key);
     }
 
@@ -119,7 +87,7 @@ class MessageStreamRequests {
   CompletableFuture<?> streamAsync(RaftClientRequest request) {
     final MessageStreamRequestTypeProto stream = request.getType().getMessageStream();
     Preconditions.assertTrue(!stream.getEndOfRequest());
-    final Key key = new Key(request.getClientId(), stream.getStreamId());
+    final ClientInvocationId key = ClientInvocationId.valueOf(request.getClientId(), stream.getStreamId());
     final PendingStream pending = streams.computeIfAbsent(key);
     return pending.append(stream.getMessageId(), request.getMessage());
   }
@@ -127,7 +95,7 @@ class MessageStreamRequests {
   CompletableFuture<ByteString> streamEndOfRequestAsync(RaftClientRequest request) {
     final MessageStreamRequestTypeProto stream = request.getType().getMessageStream();
     Preconditions.assertTrue(stream.getEndOfRequest());
-    final Key key = new Key(request.getClientId(), stream.getStreamId());
+    final ClientInvocationId key = ClientInvocationId.valueOf(request.getClientId(), stream.getStreamId());
 
     final PendingStream pending = streams.remove(key);
     if (pending == null) {
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java
index 3bb58f3..9318a39 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java
@@ -517,10 +517,9 @@ public class RaftServerImpl implements RaftServerProtocol, RaftServerAsynchronou
         .setCommitInfos(getCommitInfos());
   }
 
-  private RaftClientReply.Builder newReplyBuilder(ClientId clientId, long callId, long logIndex) {
+  private RaftClientReply.Builder newReplyBuilder(ClientInvocationId invocationId, long logIndex) {
     return RaftClientReply.newBuilder()
-        .setClientId(clientId)
-        .setCallId(callId)
+        .setClientInvocationId(invocationId)
         .setLogIndex(logIndex)
         .setServerId(getMemberId())
         .setCommitInfos(getCommitInfos());
@@ -556,7 +555,7 @@ public class RaftServerImpl implements RaftServerProtocol, RaftServerAsynchronou
     }
     final LeaderState leaderState = role.getLeaderState().orElse(null);
     if (leaderState == null || !leaderState.isReady()) {
-      RetryCache.CacheEntry cacheEntry = retryCache.get(request.getClientId(), request.getCallId());
+      final RetryCache.CacheEntry cacheEntry = retryCache.get(ClientInvocationId.valueOf(request));
       if (cacheEntry != null && cacheEntry.isCompletedNormally()) {
         return cacheEntry.getReplyFuture();
       }
@@ -708,8 +707,7 @@ public class RaftServerImpl implements RaftServerProtocol, RaftServerAsynchronou
         replyFuture = streamAsync(request);
       } else {
         // query the retry cache
-        RetryCache.CacheQueryResult previousResult = retryCache.queryCache(
-            request.getClientId(), request.getCallId());
+        final RetryCache.CacheQueryResult previousResult = retryCache.queryCache(ClientInvocationId.valueOf(request));
         if (previousResult.isRetry()) {
           // if the previous attempt is still pending or it succeeded, return its
           // future
@@ -1499,11 +1497,9 @@ public class RaftServerImpl implements RaftServerProtocol, RaftServerAsynchronou
   private CompletableFuture<Message> replyPendingRequest(
       LogEntryProto logEntry, CompletableFuture<Message> stateMachineFuture) {
     Preconditions.assertTrue(logEntry.hasStateMachineLogEntry());
-    final StateMachineLogEntryProto smLog = logEntry.getStateMachineLogEntry();
+    final ClientInvocationId invocationId = ClientInvocationId.valueOf(logEntry.getStateMachineLogEntry());
     // update the retry cache
-    final ClientId clientId = ClientId.valueOf(smLog.getClientId());
-    final long callId = smLog.getCallId();
-    final RetryCache.CacheEntry cacheEntry = retryCache.getOrCreateEntry(clientId, callId);
+    final RetryCache.CacheEntry cacheEntry = retryCache.getOrCreateEntry(invocationId);
     if (isLeader()) {
       Preconditions.assertTrue(cacheEntry != null && !cacheEntry.isCompletedNormally(),
               "retry cache entry should be pending: %s", cacheEntry);
@@ -1514,7 +1510,7 @@ public class RaftServerImpl implements RaftServerProtocol, RaftServerAsynchronou
 
     final long logIndex = logEntry.getIndex();
     return stateMachineFuture.whenComplete((reply, exception) -> {
-      final RaftClientReply.Builder b = newReplyBuilder(clientId, callId, logIndex);
+      final RaftClientReply.Builder b = newReplyBuilder(invocationId, logIndex);
       final RaftClientReply r;
       if (exception == null) {
         r = b.setSuccess().setMessage(reply).build();
@@ -1586,12 +1582,10 @@ public class RaftServerImpl implements RaftServerProtocol, RaftServerAsynchronou
    */
   public void notifyTruncatedLogEntry(LogEntryProto logEntry) {
     if (logEntry.hasStateMachineLogEntry()) {
-      final StateMachineLogEntryProto smLog = logEntry.getStateMachineLogEntry();
-      final ClientId clientId = ClientId.valueOf(smLog.getClientId());
-      final long callId = smLog.getCallId();
-      final RetryCache.CacheEntry cacheEntry = getRetryCache().get(clientId, callId);
+      final ClientInvocationId invocationId = ClientInvocationId.valueOf(logEntry.getStateMachineLogEntry());
+      final RetryCache.CacheEntry cacheEntry = getRetryCache().get(invocationId);
       if (cacheEntry != null) {
-        cacheEntry.failWithReply(newReplyBuilder(clientId, callId, logEntry.getIndex())
+        cacheEntry.failWithReply(newReplyBuilder(invocationId, logEntry.getIndex())
             .setException(generateNotLeaderException())
             .build());
       }
diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/RetryCache.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/RetryCache.java
index 11bb11e..4cbcaf9 100644
--- a/ratis-server/src/main/java/org/apache/ratis/server/impl/RetryCache.java
+++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/RetryCache.java
@@ -1,4 +1,4 @@
-/**
+/*
  * Licensed to the Apache Software Foundation (ASF) under one
  * or more contributor license agreements.  See the NOTICE file
  * distributed with this work for additional information
@@ -21,7 +21,7 @@ import java.io.Closeable;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
 
-import org.apache.ratis.protocol.ClientId;
+import org.apache.ratis.protocol.ClientInvocationId;
 import org.apache.ratis.protocol.RaftClientReply;
 import org.apache.ratis.thirdparty.com.google.common.annotations.VisibleForTesting;
 import org.apache.ratis.thirdparty.com.google.common.cache.Cache;
@@ -35,44 +35,12 @@ import org.slf4j.LoggerFactory;
 public class RetryCache implements Closeable {
   static final Logger LOG = LoggerFactory.getLogger(RetryCache.class);
 
-  static class CacheKey {
-    private final ClientId clientId;
-    private final long callId;
-
-    CacheKey(ClientId clientId, long callId) {
-      this.clientId = clientId;
-      this.callId = callId;
-    }
-
-    @Override
-    public int hashCode() {
-      return clientId.hashCode() ^ Long.hashCode(callId);
-    }
-
-    @Override
-    public boolean equals(Object obj) {
-      if (this == obj) {
-        return true;
-      }
-      if (obj instanceof CacheKey) {
-        CacheKey e = (CacheKey) obj;
-        return e.clientId.equals(clientId) && callId == e.callId;
-      }
-      return false;
-    }
-
-    @Override
-    public String toString() {
-      return clientId.toString() + ":" + this.callId;
-    }
-  }
-
   /**
    * CacheEntry is tracked using unique client ID and callId of the RPC request
    */
   @VisibleForTesting
   public static class CacheEntry {
-    private final CacheKey key;
+    private final ClientInvocationId key;
     private final CompletableFuture<RaftClientReply> replyFuture =
         new CompletableFuture<>();
 
@@ -85,7 +53,7 @@ public class RetryCache implements Closeable {
      */
     private volatile boolean failed = false;
 
-    CacheEntry(CacheKey key) {
+    CacheEntry(ClientInvocationId key) {
       this.key = key;
     }
 
@@ -125,7 +93,7 @@ public class RetryCache implements Closeable {
       return replyFuture;
     }
 
-    CacheKey getKey() {
+    ClientInvocationId getKey() {
       return key;
     }
   }
@@ -148,7 +116,7 @@ public class RetryCache implements Closeable {
     }
   }
 
-  private final Cache<CacheKey, CacheEntry> cache;
+  private final Cache<ClientInvocationId, CacheEntry> cache;
 
   /**
    * @param expirationTime time for an entry to expire in milliseconds
@@ -160,8 +128,7 @@ public class RetryCache implements Closeable {
         .build();
   }
 
-  CacheEntry getOrCreateEntry(ClientId clientId, long callId) {
-    final CacheKey key = new CacheKey(clientId, callId);
+  CacheEntry getOrCreateEntry(ClientInvocationId key) {
     final CacheEntry entry;
     try {
       entry = cache.get(key, () -> new CacheEntry(key));
@@ -176,8 +143,7 @@ public class RetryCache implements Closeable {
     return newEntry;
   }
 
-  CacheQueryResult queryCache(ClientId clientId, long callId) {
-    CacheKey key = new CacheKey(clientId, callId);
+  CacheQueryResult queryCache(ClientInvocationId key) {
     final CacheEntry newEntry = new CacheEntry(key);
     CacheEntry cacheEntry;
     try {
@@ -220,8 +186,8 @@ public class RetryCache implements Closeable {
   }
 
   @VisibleForTesting
-  CacheEntry get(ClientId clientId, long callId) {
-    return cache.getIfPresent(new CacheKey(clientId, callId));
+  CacheEntry get(ClientInvocationId key) {
+    return cache.getIfPresent(key);
   }
 
   @Override
diff --git a/ratis-server/src/test/java/org/apache/ratis/RaftBasicTests.java b/ratis-server/src/test/java/org/apache/ratis/RaftBasicTests.java
index f20ea5b..0568f73 100644
--- a/ratis-server/src/test/java/org/apache/ratis/RaftBasicTests.java
+++ b/ratis-server/src/test/java/org/apache/ratis/RaftBasicTests.java
@@ -27,6 +27,7 @@ import org.apache.ratis.client.impl.RaftClientTestUtil;
 import org.apache.ratis.metrics.MetricRegistries;
 import org.apache.ratis.metrics.MetricRegistryInfo;
 import org.apache.ratis.metrics.RatisMetricRegistry;
+import org.apache.ratis.protocol.ClientInvocationId;
 import org.apache.ratis.protocol.RaftClientReply;
 import org.apache.ratis.protocol.RaftPeerId;
 import org.apache.ratis.server.RaftServer;
@@ -428,11 +429,11 @@ public abstract class RaftBasicTests<CLUSTER extends MiniRaftCluster>
     final Timestamp startTime = Timestamp.currentTime();
     try (final RaftClient client = cluster.createClient()) {
       // Get the next callId to be used by the client
-      long callId = RaftClientTestUtil.getCallId(client);
+      final ClientInvocationId invocationId = RaftClientTestUtil.getClientInvocationId(client);
       // Create an entry corresponding to the callId and clientId
       // in each server's retry cache.
       cluster.getServerAliveStream().forEach(
-          raftServer -> RetryCacheTestUtil.getOrCreateEntry(raftServer.getRetryCache(), client.getId(), callId));
+          raftServer -> RetryCacheTestUtil.getOrCreateEntry(raftServer.getRetryCache(), invocationId));
       // Client request for the callId now waits
       // as there is already a cache entry in the server for the request.
       // Ideally the client request should timeout and the client should retry.
diff --git a/ratis-server/src/test/java/org/apache/ratis/server/impl/RaftServerTestUtil.java b/ratis-server/src/test/java/org/apache/ratis/server/impl/RaftServerTestUtil.java
index fc1bcb0..34e8112 100644
--- a/ratis-server/src/test/java/org/apache/ratis/server/impl/RaftServerTestUtil.java
+++ b/ratis-server/src/test/java/org/apache/ratis/server/impl/RaftServerTestUtil.java
@@ -21,6 +21,7 @@ import org.apache.log4j.Level;
 import org.apache.ratis.MiniRaftCluster;
 import org.apache.ratis.proto.RaftProtos.RaftPeerRole;
 import org.apache.ratis.protocol.ClientId;
+import org.apache.ratis.protocol.ClientInvocationId;
 import org.apache.ratis.protocol.RaftGroupId;
 import org.apache.ratis.protocol.RaftPeer;
 import org.apache.ratis.protocol.RaftPeerId;
@@ -89,9 +90,8 @@ public class RaftServerTestUtil {
     return server.getRetryCache().size();
   }
 
-  public static RetryCache.CacheEntry getRetryEntry(RaftServerImpl server,
-      ClientId clientId, long callId) {
-    return server.getRetryCache().get(clientId, callId);
+  public static RetryCache.CacheEntry getRetryEntry(RaftServerImpl server, ClientId clientId, long callId) {
+    return server.getRetryCache().get(ClientInvocationId.valueOf(clientId, callId));
   }
 
   public static boolean isRetryCacheEntryFailed(RetryCache.CacheEntry entry) {
diff --git a/ratis-server/src/test/java/org/apache/ratis/server/impl/RetryCacheTestUtil.java b/ratis-server/src/test/java/org/apache/ratis/server/impl/RetryCacheTestUtil.java
index a8a2465..2e6875f 100644
--- a/ratis-server/src/test/java/org/apache/ratis/server/impl/RetryCacheTestUtil.java
+++ b/ratis-server/src/test/java/org/apache/ratis/server/impl/RetryCacheTestUtil.java
@@ -18,8 +18,7 @@
 package org.apache.ratis.server.impl;
 
 import org.apache.ratis.proto.RaftProtos.LogEntryProto;
-import org.apache.ratis.proto.RaftProtos.StateMachineLogEntryProto;
-import org.apache.ratis.protocol.ClientId;
+import org.apache.ratis.protocol.ClientInvocationId;
 import org.apache.ratis.util.TimeDuration;
 import org.junit.Assert;
 
@@ -32,23 +31,19 @@ public class RetryCacheTestUtil {
 
   public static void createEntry(RetryCache cache, LogEntryProto logEntry){
     if(logEntry.hasStateMachineLogEntry()) {
-      final StateMachineLogEntryProto smLogEntry = logEntry.getStateMachineLogEntry();
-      final ClientId clientId = ClientId.valueOf(smLogEntry.getClientId());
-      final long callId = smLogEntry.getCallId();
-      cache.getOrCreateEntry(clientId, callId);
+      final ClientInvocationId invocationId = ClientInvocationId.valueOf(logEntry.getStateMachineLogEntry());
+      cache.getOrCreateEntry(invocationId);
     }
   }
 
   public static void assertFailure(RetryCache cache, LogEntryProto logEntry, boolean isFailed) {
     if(logEntry.hasStateMachineLogEntry()) {
-      final StateMachineLogEntryProto smLogEntry = logEntry.getStateMachineLogEntry();
-      final ClientId clientId = ClientId.valueOf(smLogEntry.getClientId());
-      final long callId = smLogEntry.getCallId();
-      Assert.assertEquals(isFailed, cache.get(clientId, callId).isFailed());
+      final ClientInvocationId invocationId = ClientInvocationId.valueOf(logEntry.getStateMachineLogEntry());
+      Assert.assertEquals(isFailed, cache.get(invocationId).isFailed());
     }
   }
 
-  public static void getOrCreateEntry(RetryCache cache, ClientId clientId, long callId){
-    cache.getOrCreateEntry(clientId, callId);
+  public static void getOrCreateEntry(RetryCache cache, ClientInvocationId invocationId) {
+    cache.getOrCreateEntry(invocationId);
   }
 }
diff --git a/ratis-server/src/test/java/org/apache/ratis/server/impl/TestRetryCacheMetrics.java b/ratis-server/src/test/java/org/apache/ratis/server/impl/TestRetryCacheMetrics.java
index cdde7b6..208f34a 100644
--- a/ratis-server/src/test/java/org/apache/ratis/server/impl/TestRetryCacheMetrics.java
+++ b/ratis-server/src/test/java/org/apache/ratis/server/impl/TestRetryCacheMetrics.java
@@ -24,6 +24,7 @@ import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
 import org.apache.ratis.metrics.RatisMetricRegistry;
+import org.apache.ratis.protocol.ClientInvocationId;
 import org.apache.ratis.protocol.ClientId;
 import org.apache.ratis.protocol.RaftGroupId;
 import org.apache.ratis.protocol.RaftGroupMemberId;
@@ -64,7 +65,7 @@ public class TestRetryCacheMetrics {
       checkEntryCount(0);
 
       ClientId clientId = ClientId.randomId();
-      RetryCache.CacheKey key = new RetryCache.CacheKey(clientId, 1);
+      final ClientInvocationId key = ClientInvocationId.valueOf(clientId, 1);
       RetryCache.CacheEntry entry = new RetryCache.CacheEntry(key);
 
       retryCache.refreshEntry(entry);
@@ -79,13 +80,13 @@ public class TestRetryCacheMetrics {
       checkHit(0, 1.0);
       checkMiss(0, 0.0);
 
-      ClientId clientId = ClientId.randomId();
-      retryCache.getOrCreateEntry(clientId, 2);
+      final ClientInvocationId invocationId = ClientInvocationId.valueOf(ClientId.randomId(), 2);
+      retryCache.getOrCreateEntry(invocationId);
 
       checkHit(0, 0.0);
       checkMiss(1, 1.0);
 
-      retryCache.getOrCreateEntry(clientId, 2);
+      retryCache.getOrCreateEntry(invocationId);
 
       checkHit(1, 0.5);
       checkMiss(1, 0.5);
diff --git a/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamBaseTest.java b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamBaseTest.java
index b00d1ff..0ca5377 100644
--- a/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamBaseTest.java
+++ b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamBaseTest.java
@@ -26,6 +26,7 @@ import org.apache.ratis.conf.RaftProperties;
 import org.apache.ratis.datastream.impl.DataStreamReplyByteBuffer;
 import org.apache.ratis.datastream.impl.DataStreamRequestByteBuffer;
 import org.apache.ratis.proto.RaftProtos.*;
+import org.apache.ratis.protocol.ClientInvocationId;
 import org.apache.ratis.protocol.ClientId;
 import org.apache.ratis.protocol.DataStreamReply;
 import org.apache.ratis.protocol.GroupInfoReply;
@@ -63,7 +64,6 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.List;
-import java.util.Objects;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.CompletionException;
 import java.util.concurrent.ConcurrentHashMap;
@@ -89,43 +89,17 @@ abstract class DataStreamBaseTest extends BaseTest {
   private final Executor executor = Executors.newFixedThreadPool(16);
 
   static class MultiDataStreamStateMachine extends BaseStateMachine {
-    static class Key {
-      private final ClientId clientId;
-      private final long callId;
-
-      Key(RaftClientRequest request) {
-        this.clientId = request.getClientId();
-        this.callId = request.getCallId();
-      }
-
-      @Override
-      public boolean equals(Object obj) {
-        if (this == obj) {
-          return true;
-        } else if (obj == null || getClass() != obj.getClass()) {
-          return false;
-        }
-        final Key that = (Key) obj;
-        return this.callId == that.callId && Objects.equals(this.clientId, that.clientId);
-      }
-
-      @Override
-      public int hashCode() {
-        return Objects.hash(clientId, callId);
-      }
-    }
-
-    private final ConcurrentMap<Key, SingleDataStream> streams = new ConcurrentHashMap<>();
+    private final ConcurrentMap<ClientInvocationId, SingleDataStream> streams = new ConcurrentHashMap<>();
 
     @Override
     public CompletableFuture<DataStream> stream(RaftClientRequest request) {
       final SingleDataStream s = new SingleDataStream(request);
-      streams.put(new Key(request), s);
+      streams.put(ClientInvocationId.valueOf(request), s);
       return CompletableFuture.completedFuture(s);
     }
 
     SingleDataStream getSingleDataStream(RaftClientRequest request) {
-      return streams.get(new Key(request));
+      return streams.get(ClientInvocationId.valueOf(request));
     }
   }