You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this plain-text copy.
Posted to commits@ratis.apache.org by ru...@apache.org on 2020/11/21 02:16:07 UTC
[incubator-ratis] branch master updated: RATIS-1168. Move the utils
from DataStreamBaseTest to a new file. (#292)
This is an automated email from the ASF dual-hosted git repository.
runzhiwang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-ratis.git
The following commit(s) were added to refs/heads/master by this push:
new 507f09d RATIS-1168. Move the utils from DataStreamBaseTest to a new file. (#292)
507f09d is described below
commit 507f09dfabbfd9eb25852d1bd24fbff9377909bc
Author: Tsz-Wo Nicholas Sze <sz...@apache.org>
AuthorDate: Sat Nov 21 10:15:59 2020 +0800
RATIS-1168. Move the utils from DataStreamBaseTest to a new file. (#292)
---
.../ratis/datastream/DataStreamBaseTest.java | 251 ++++-----------------
.../ratis/datastream/DataStreamClusterTests.java | 6 +-
.../ratis/datastream/DataStreamTestUtils.java | 245 ++++++++++++++++++++
.../ratis/datastream/TestDataStreamNetty.java | 4 +-
4 files changed, 295 insertions(+), 211 deletions(-)
diff --git a/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamBaseTest.java b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamBaseTest.java
index 17f675a..113fcb6 100644
--- a/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamBaseTest.java
+++ b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamBaseTest.java
@@ -23,18 +23,25 @@ import org.apache.ratis.client.RaftClient;
import org.apache.ratis.client.impl.ClientProtoUtils;
import org.apache.ratis.client.impl.DataStreamClientImpl.DataStreamOutputImpl;
import org.apache.ratis.conf.RaftProperties;
+import org.apache.ratis.datastream.DataStreamTestUtils.DataChannel;
+import org.apache.ratis.datastream.DataStreamTestUtils.MultiDataStreamStateMachine;
+import org.apache.ratis.datastream.DataStreamTestUtils.SingleDataStream;
import org.apache.ratis.datastream.impl.DataStreamReplyByteBuffer;
-import org.apache.ratis.datastream.impl.DataStreamRequestByteBuffer;
-import org.apache.ratis.proto.RaftProtos.*;
-import org.apache.ratis.protocol.ClientInvocationId;
+import org.apache.ratis.proto.RaftProtos.AppendEntriesReplyProto;
+import org.apache.ratis.proto.RaftProtos.AppendEntriesRequestProto;
+import org.apache.ratis.proto.RaftProtos.InstallSnapshotReplyProto;
+import org.apache.ratis.proto.RaftProtos.InstallSnapshotRequestProto;
+import org.apache.ratis.proto.RaftProtos.RaftClientReplyProto;
+import org.apache.ratis.proto.RaftProtos.RequestVoteReplyProto;
+import org.apache.ratis.proto.RaftProtos.RequestVoteRequestProto;
import org.apache.ratis.protocol.ClientId;
+import org.apache.ratis.protocol.ClientInvocationId;
import org.apache.ratis.protocol.DataStreamReply;
import org.apache.ratis.protocol.GroupInfoReply;
import org.apache.ratis.protocol.GroupInfoRequest;
import org.apache.ratis.protocol.GroupListReply;
import org.apache.ratis.protocol.GroupListRequest;
import org.apache.ratis.protocol.GroupManagementRequest;
-import org.apache.ratis.protocol.RaftClientMessage;
import org.apache.ratis.protocol.RaftClientReply;
import org.apache.ratis.protocol.RaftClientRequest;
import org.apache.ratis.protocol.RaftGroup;
@@ -42,27 +49,20 @@ import org.apache.ratis.protocol.RaftGroupId;
import org.apache.ratis.protocol.RaftGroupMemberId;
import org.apache.ratis.protocol.RaftPeer;
import org.apache.ratis.protocol.RaftPeerId;
-import org.apache.ratis.proto.RaftProtos.DataStreamPacketHeaderProto.Type;
import org.apache.ratis.protocol.SetConfigurationRequest;
-import org.apache.ratis.protocol.exceptions.AlreadyClosedException;
import org.apache.ratis.rpc.RpcType;
import org.apache.ratis.server.DataStreamMap;
import org.apache.ratis.server.RaftServer;
import org.apache.ratis.server.impl.DataStreamServerImpl;
import org.apache.ratis.server.impl.RaftServerTestUtil;
import org.apache.ratis.server.impl.ServerFactory;
+import org.apache.ratis.statemachine.StateMachine;
import org.apache.ratis.statemachine.StateMachine.StateMachineDataChannel;
-import org.apache.ratis.statemachine.impl.BaseStateMachine;
-import org.apache.ratis.statemachine.StateMachine.DataStream;
-import org.apache.ratis.thirdparty.com.google.protobuf.ByteString;
-import org.apache.ratis.util.JavaUtils;
import org.apache.ratis.util.LifeCycle;
import org.apache.ratis.util.NetUtils;
import org.junit.Assert;
-import org.slf4j.Logger;
import java.io.IOException;
-import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
@@ -73,108 +73,9 @@ import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
-import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors;
abstract class DataStreamBaseTest extends BaseTest {
- static final int MODULUS = 23;
-
- static byte pos2byte(int pos) {
- return (byte) ('A' + pos%MODULUS);
- }
-
- static final ByteString MOCK = ByteString.copyFromUtf8("mock");
-
- static ByteString bytesWritten2ByteString(long bytesWritten) {
- return ByteString.copyFromUtf8("bytesWritten=" + bytesWritten);
- }
-
- private final Executor executor = Executors.newFixedThreadPool(16);
-
- static class MultiDataStreamStateMachine extends BaseStateMachine {
- private final ConcurrentMap<ClientInvocationId, SingleDataStream> streams = new ConcurrentHashMap<>();
-
- @Override
- public CompletableFuture<DataStream> stream(RaftClientRequest request) {
- final SingleDataStream s = new SingleDataStream(request);
- streams.put(ClientInvocationId.valueOf(request), s);
- return CompletableFuture.completedFuture(s);
- }
-
- SingleDataStream getSingleDataStream(RaftClientRequest request) {
- return streams.get(ClientInvocationId.valueOf(request));
- }
- }
-
- static class SingleDataStream implements DataStream {
- private int byteWritten = 0;
- private final RaftClientRequest writeRequest;
- private int forcedPosition = 0;
-
- final StateMachineDataChannel channel = new StateMachineDataChannel() {
- @Override
- public void force(boolean metadata) {
- forcedPosition = byteWritten;
- }
-
- private volatile boolean open = true;
-
- @Override
- public int write(ByteBuffer src) {
- if (!open) {
- throw new IllegalStateException("Already closed");
- }
- final int remaining = src.remaining();
- for(; src.remaining() > 0; ) {
- Assert.assertEquals(pos2byte(byteWritten), src.get());
- byteWritten += 1;
- }
- return remaining;
- }
-
- @Override
- public boolean isOpen() {
- return open;
- }
-
- @Override
- public void close() {
- open = false;
- }
- };
-
- @Override
- public StateMachineDataChannel getWritableByteChannel() {
- return channel;
- }
-
- @Override
- public CompletableFuture<?> cleanUp() {
- try {
- channel.close();
- } catch (Throwable t) {
- return JavaUtils.completeExceptionally(t);
- }
- return CompletableFuture.completedFuture(null);
- }
-
- SingleDataStream(RaftClientRequest request) {
- this.writeRequest = request;
- }
-
- public int getByteWritten() {
- return byteWritten;
- }
-
- public RaftClientRequest getWriteRequest() {
- return writeRequest;
- }
-
- public int getForcedPosition() {
- return forcedPosition;
- }
- }
-
static class MyDivision implements RaftServer.Division {
private final MultiDataStreamStateMachine stateMachine = new MultiDataStreamStateMachine();
private final DataStreamMap streamMap;
@@ -240,18 +141,29 @@ abstract class DataStreamBaseTest extends BaseTest {
private List<Server> servers;
private RaftGroup raftGroup;
+ private final Executor executor = Executors.newFixedThreadPool(16);
Server getPrimaryServer() {
return servers.get(0);
}
- protected RaftServer newRaftServer(RaftPeer peer, RaftProperties properties) {
- return new RaftServer() {
+ protected MyRaftServer newRaftServer(RaftPeer peer, RaftProperties properties) {
+ return new MyRaftServer(peer.getId(), properties);
+ }
+
+ static class MyRaftServer implements RaftServer {
+ private final RaftPeerId id;
+ private final RaftProperties properties;
private final ConcurrentMap<RaftGroupId, MyDivision> divisions = new ConcurrentHashMap<>();
+ MyRaftServer(RaftPeerId id, RaftProperties properties) {
+ this.id = id;
+ this.properties = properties;
+ }
+
@Override
public RaftPeerId getId() {
- return peer.getId();
+ return id;
}
@Override
@@ -301,14 +213,21 @@ abstract class DataStreamBaseTest extends BaseTest {
@Override
public CompletableFuture<RaftClientReply> submitClientRequestAsync(RaftClientRequest request) {
- final MultiDataStreamStateMachine stateMachine = getDivision(request.getRaftGroupId()).getStateMachine();
- final SingleDataStream stream = stateMachine.getSingleDataStream(request);
- Assert.assertFalse(stream.getWritableByteChannel().isOpen());
- return CompletableFuture.completedFuture(RaftClientReply.newBuilder()
+ final MyDivision d = getDivision(request.getRaftGroupId());
+ return d.getDataStreamMap()
+ .remove(ClientInvocationId.valueOf(request))
+ .thenApply(StateMachine.DataStream::getWritableByteChannel)
+ .thenApply(channel -> buildRaftClientReply(request, channel));
+ }
+
+ static RaftClientReply buildRaftClientReply(RaftClientRequest request, StateMachineDataChannel channel) {
+ Assert.assertTrue(channel instanceof DataChannel);
+ final DataChannel dataChannel = (DataChannel) channel;
+ return RaftClientReply.newBuilder()
.setRequest(request)
.setSuccess()
- .setMessage(() -> bytesWritten2ByteString(stream.getByteWritten()))
- .build());
+ .setMessage(() -> DataStreamTestUtils.bytesWritten2ByteString(dataChannel.getBytesWritten()))
+ .build();
}
@Override
@@ -373,7 +292,6 @@ abstract class DataStreamBaseTest extends BaseTest {
public LifeCycle.State getLifeCycleState() {
return null;
}
- };
}
@@ -465,7 +383,6 @@ abstract class DataStreamBaseTest extends BaseTest {
}
}
-
void runTestMockCluster(ClientId clientId, int bufferSize, int bufferNum,
Exception expectedException, Exception headerException)
throws IOException {
@@ -491,42 +408,9 @@ abstract class DataStreamBaseTest extends BaseTest {
}
}
- static int writeAndAssertReplies(DataStreamOutputImpl out, int bufferSize, int bufferNum) {
- final List<CompletableFuture<DataStreamReply>> futures = new ArrayList<>();
- final List<Integer> sizes = new ArrayList<>();
-
- //send data
- final int halfBufferSize = bufferSize/2;
- int dataSize = 0;
- for(int i = 0; i < bufferNum; i++) {
- final int size = halfBufferSize + ThreadLocalRandom.current().nextInt(halfBufferSize);
- sizes.add(size);
-
- final ByteBuffer bf = initBuffer(dataSize, size);
- futures.add(out.writeAsync(bf, i == bufferNum - 1));
- dataSize += size;
- }
-
- { // check header
- final DataStreamReply reply = out.getHeaderFuture().join();
- Assert.assertTrue(reply.isSuccess());
- Assert.assertEquals(0, reply.getBytesWritten());
- Assert.assertEquals(reply.getType(), Type.STREAM_HEADER);
- }
-
- // check writeAsync requests
- for(int i = 0; i < futures.size(); i++) {
- final DataStreamReply reply = futures.get(i).join();
- Assert.assertTrue(reply.isSuccess());
- Assert.assertEquals(sizes.get(i).longValue(), reply.getBytesWritten());
- Assert.assertEquals(reply.getType(), i == futures.size() - 1 ? Type.STREAM_DATA_SYNC : Type.STREAM_DATA);
- }
- return dataSize;
- }
-
CompletableFuture<RaftClientReply> runTestDataStream(DataStreamOutputImpl out, int bufferSize, int bufferNum) {
LOG.info("start Stream{}", out.getHeader().getCallId());
- final int bytesWritten = writeAndAssertReplies(out, bufferSize, bufferNum);
+ final int bytesWritten = DataStreamTestUtils.writeAndAssertReplies(out, bufferSize, bufferNum);
try {
for (Server s : servers) {
assertHeader(s, out.getHeader(), bytesWritten);
@@ -535,32 +419,7 @@ abstract class DataStreamBaseTest extends BaseTest {
throw new CompletionException(e);
}
- return out.closeAsync().thenCompose(reply -> assertCloseReply(out, reply, bytesWritten));
- }
-
- static CompletableFuture<RaftClientReply> assertCloseReply(DataStreamOutputImpl out, DataStreamReply dataStreamReply,
- long bytesWritten) {
- // Test close idempotent
- Assert.assertSame(dataStreamReply, out.closeAsync().join());
- testFailureCase("writeAsync should fail",
- () -> out.writeAsync(DataStreamRequestByteBuffer.EMPTY_BYTE_BUFFER).join(),
- CompletionException.class, (Logger)null, AlreadyClosedException.class);
-
- try {
- final RaftClientReply reply = ClientProtoUtils.toRaftClientReply(RaftClientReplyProto.parseFrom(
- ((DataStreamReplyByteBuffer) dataStreamReply).slice()));
- assertRaftClientMessage(out.getHeader(), reply);
- if (reply.isSuccess()) {
- final ByteString bytes = reply.getMessage().getContent();
- if (!bytes.equals(MOCK)) {
- Assert.assertEquals(bytesWritten2ByteString(bytesWritten), bytes);
- }
- }
-
- return CompletableFuture.completedFuture(reply);
- } catch (Throwable t) {
- return JavaUtils.completeExceptionally(t);
- }
+ return out.closeAsync().thenCompose(reply -> DataStreamTestUtils.assertCloseReply(out, reply, bytesWritten));
}
void assertHeader(Server server, RaftClientRequest header, int dataSize) throws Exception {
@@ -570,34 +429,14 @@ abstract class DataStreamBaseTest extends BaseTest {
// check stream
final MyDivision d = server.getDivision(header.getRaftGroupId());
- Assert.assertNotNull(d.getDataStreamMap().remove(ClientInvocationId.valueOf(header)));
final SingleDataStream stream = d.getStateMachine().getSingleDataStream(header);
- Assert.assertEquals(dataSize, stream.getByteWritten());
- Assert.assertEquals(dataSize, stream.getForcedPosition());
+ final DataChannel channel = d.getStateMachine().getSingleDataStream(header).getWritableByteChannel();
+ Assert.assertEquals(dataSize, channel.getBytesWritten());
+ Assert.assertEquals(dataSize, channel.getForcedPosition());
// check writeRequest
final RaftClientRequest writeRequest = stream.getWriteRequest();
Assert.assertEquals(RaftClientRequest.dataStreamRequestType(), writeRequest.getType());
- assertRaftClientMessage(header, writeRequest);
- }
-
- static void assertRaftClientMessage(RaftClientRequest expected, RaftClientMessage computed) {
- Assert.assertNotNull(computed);
- Assert.assertEquals(expected.getClientId(), computed.getClientId());
- Assert.assertEquals(expected.getServerId(), computed.getServerId());
- Assert.assertEquals(expected.getRaftGroupId(), computed.getRaftGroupId());
- Assert.assertEquals(expected.getCallId(), computed.getCallId());
- }
-
- static ByteBuffer initBuffer(int offset, int size) {
- final ByteBuffer buffer = ByteBuffer.allocateDirect(size);
- final int length = buffer.capacity();
- buffer.position(0).limit(length);
- for (int j = 0; j < length; j++) {
- buffer.put(pos2byte(offset + j));
- }
- buffer.flip();
- Assert.assertEquals(length, buffer.remaining());
- return buffer;
+ DataStreamTestUtils.assertRaftClientMessage(header, writeRequest);
}
}
diff --git a/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamClusterTests.java b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamClusterTests.java
index bf1210e..4020761 100644
--- a/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamClusterTests.java
+++ b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamClusterTests.java
@@ -21,7 +21,7 @@ import org.apache.ratis.BaseTest;
import org.apache.ratis.MiniRaftCluster;
import org.apache.ratis.client.RaftClient;
import org.apache.ratis.client.impl.DataStreamClientImpl.DataStreamOutputImpl;
-import org.apache.ratis.datastream.DataStreamBaseTest.MultiDataStreamStateMachine;
+import org.apache.ratis.datastream.DataStreamTestUtils.MultiDataStreamStateMachine;
import org.apache.ratis.proto.RaftProtos.LogEntryProto;
import org.apache.ratis.proto.RaftProtos.StateMachineLogEntryProto;
import org.apache.ratis.protocol.ClientId;
@@ -65,7 +65,7 @@ public abstract class DataStreamClusterTests<CLUSTER extends MiniRaftCluster> ex
// send a stream request
try(final DataStreamOutputImpl out = (DataStreamOutputImpl) client.getDataStreamApi().stream()) {
- DataStreamBaseTest.writeAndAssertReplies(out, 1000, 10);
+ DataStreamTestUtils.writeAndAssertReplies(out, 1000, 10);
callId = out.getHeader().getCallId();
}
}
@@ -83,7 +83,7 @@ public abstract class DataStreamClusterTests<CLUSTER extends MiniRaftCluster> ex
if (entry.hasStateMachineLogEntry()) {
final StateMachineLogEntryProto stateMachineEntry = entry.getStateMachineLogEntry();
if (stateMachineEntry.getCallId() == callId) {
- if (clientId.equals(ClientId.valueOf(stateMachineEntry.getClientId()))) {
+ if (clientId.toByteString().equals(stateMachineEntry.getClientId())) {
return entry;
}
}
diff --git a/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamTestUtils.java b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamTestUtils.java
new file mode 100644
index 0000000..a834cd8
--- /dev/null
+++ b/ratis-test/src/test/java/org/apache/ratis/datastream/DataStreamTestUtils.java
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.ratis.datastream;
+
+import org.apache.ratis.BaseTest;
+import org.apache.ratis.client.impl.ClientProtoUtils;
+import org.apache.ratis.client.impl.DataStreamClientImpl.DataStreamOutputImpl;
+import org.apache.ratis.datastream.impl.DataStreamReplyByteBuffer;
+import org.apache.ratis.datastream.impl.DataStreamRequestByteBuffer;
+import org.apache.ratis.proto.RaftProtos.DataStreamPacketHeaderProto.Type;
+import org.apache.ratis.proto.RaftProtos.LogEntryProto;
+import org.apache.ratis.proto.RaftProtos.RaftClientReplyProto;
+import org.apache.ratis.protocol.ClientInvocationId;
+import org.apache.ratis.protocol.DataStreamReply;
+import org.apache.ratis.protocol.RaftClientMessage;
+import org.apache.ratis.protocol.RaftClientReply;
+import org.apache.ratis.protocol.RaftClientRequest;
+import org.apache.ratis.protocol.exceptions.AlreadyClosedException;
+import org.apache.ratis.statemachine.StateMachine.DataStream;
+import org.apache.ratis.statemachine.StateMachine.StateMachineDataChannel;
+import org.apache.ratis.statemachine.impl.BaseStateMachine;
+import org.apache.ratis.thirdparty.com.google.protobuf.ByteString;
+import org.apache.ratis.util.JavaUtils;
+import org.junit.Assert;
+import org.slf4j.Logger;
+
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.CompletionException;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.ThreadLocalRandom;
+
+public interface DataStreamTestUtils {
+ ByteString MOCK = ByteString.copyFromUtf8("mock");
+ int MODULUS = 23;
+
+ static byte pos2byte(int pos) {
+ return (byte) ('A' + pos % MODULUS);
+ }
+
+ static ByteString bytesWritten2ByteString(long bytesWritten) {
+ return ByteString.copyFromUtf8("bytesWritten=" + bytesWritten);
+ }
+
+ class MultiDataStreamStateMachine extends BaseStateMachine {
+ private final ConcurrentMap<ClientInvocationId, SingleDataStream> streams = new ConcurrentHashMap<>();
+
+ @Override
+ public CompletableFuture<DataStream> stream(RaftClientRequest request) {
+ final SingleDataStream s = new SingleDataStream(request);
+ streams.put(ClientInvocationId.valueOf(request), s);
+ return CompletableFuture.completedFuture(s);
+ }
+
+ @Override
+ public CompletableFuture<?> link(DataStream stream, LogEntryProto entry) {
+ final SingleDataStream s = getSingleDataStream(ClientInvocationId.valueOf(entry.getStateMachineLogEntry()));
+ s.setLogEntry(entry);
+ return CompletableFuture.completedFuture(null);
+ }
+
+ SingleDataStream getSingleDataStream(RaftClientRequest request) {
+ return getSingleDataStream(ClientInvocationId.valueOf(request));
+ }
+
+ SingleDataStream getSingleDataStream(ClientInvocationId invocationId) {
+ return streams.get(invocationId);
+ }
+ }
+
+ class SingleDataStream implements DataStream {
+ private final RaftClientRequest writeRequest;
+ private final DataChannel channel = new DataChannel();
+ private volatile LogEntryProto logEntry;
+
+ SingleDataStream(RaftClientRequest request) {
+ this.writeRequest = request;
+ }
+
+ @Override
+ public DataChannel getWritableByteChannel() {
+ return channel;
+ }
+
+ @Override
+ public CompletableFuture<?> cleanUp() {
+ try {
+ channel.close();
+ } catch (Throwable t) {
+ return JavaUtils.completeExceptionally(t);
+ }
+ return CompletableFuture.completedFuture(null);
+ }
+
+ void setLogEntry(LogEntryProto logEntry) {
+ this.logEntry = logEntry;
+ }
+
+ LogEntryProto getLogEntry() {
+ return logEntry;
+ }
+
+ RaftClientRequest getWriteRequest() {
+ return writeRequest;
+ }
+ }
+
+ class DataChannel implements StateMachineDataChannel {
+ private volatile boolean open = true;
+ private int bytesWritten = 0;
+ private int forcedPosition = 0;
+
+ int getBytesWritten() {
+ return bytesWritten;
+ }
+
+ int getForcedPosition() {
+ return forcedPosition;
+ }
+
+ @Override
+ public void force(boolean metadata) {
+ forcedPosition = bytesWritten;
+ }
+
+ @Override
+ public int write(ByteBuffer src) {
+ if (!open) {
+ throw new IllegalStateException("Already closed");
+ }
+ final int remaining = src.remaining();
+ for (; src.remaining() > 0; ) {
+ Assert.assertEquals(pos2byte(bytesWritten), src.get());
+ bytesWritten += 1;
+ }
+ return remaining;
+ }
+
+ @Override
+ public boolean isOpen() {
+ return open;
+ }
+
+ @Override
+ public void close() {
+ open = false;
+ }
+ }
+
+ static int writeAndAssertReplies(DataStreamOutputImpl out, int bufferSize, int bufferNum) {
+ final List<CompletableFuture<DataStreamReply>> futures = new ArrayList<>();
+ final List<Integer> sizes = new ArrayList<>();
+
+ //send data
+ final int halfBufferSize = bufferSize / 2;
+ int dataSize = 0;
+ for (int i = 0; i < bufferNum; i++) {
+ final int size = halfBufferSize + ThreadLocalRandom.current().nextInt(halfBufferSize);
+ sizes.add(size);
+
+ final ByteBuffer bf = initBuffer(dataSize, size);
+ futures.add(out.writeAsync(bf, i == bufferNum - 1));
+ dataSize += size;
+ }
+
+ { // check header
+ final DataStreamReply reply = out.getHeaderFuture().join();
+ Assert.assertTrue(reply.isSuccess());
+ Assert.assertEquals(0, reply.getBytesWritten());
+ Assert.assertEquals(reply.getType(), Type.STREAM_HEADER);
+ }
+
+ // check writeAsync requests
+ for (int i = 0; i < futures.size(); i++) {
+ final DataStreamReply reply = futures.get(i).join();
+ Assert.assertTrue(reply.isSuccess());
+ Assert.assertEquals(sizes.get(i).longValue(), reply.getBytesWritten());
+ Assert.assertEquals(reply.getType(), i == futures.size() - 1 ? Type.STREAM_DATA_SYNC : Type.STREAM_DATA);
+ }
+ return dataSize;
+ }
+
+ static CompletableFuture<RaftClientReply> assertCloseReply(DataStreamOutputImpl out, DataStreamReply dataStreamReply,
+ long bytesWritten) {
+ // Test close idempotent
+ Assert.assertSame(dataStreamReply, out.closeAsync().join());
+ BaseTest.testFailureCase("writeAsync should fail",
+ () -> out.writeAsync(DataStreamRequestByteBuffer.EMPTY_BYTE_BUFFER).join(),
+ CompletionException.class, (Logger) null, AlreadyClosedException.class);
+
+ try {
+ final RaftClientReply reply = ClientProtoUtils.toRaftClientReply(RaftClientReplyProto.parseFrom(
+ ((DataStreamReplyByteBuffer) dataStreamReply).slice()));
+ assertRaftClientMessage(out.getHeader(), reply);
+ if (reply.isSuccess()) {
+ final ByteString bytes = reply.getMessage().getContent();
+ if (!bytes.equals(MOCK)) {
+ Assert.assertEquals(bytesWritten2ByteString(bytesWritten), bytes);
+ }
+ }
+
+ return CompletableFuture.completedFuture(reply);
+ } catch (Throwable t) {
+ return JavaUtils.completeExceptionally(t);
+ }
+ }
+
+ static void assertRaftClientMessage(RaftClientMessage expected, RaftClientMessage computed) {
+ Assert.assertNotNull(computed);
+ Assert.assertEquals(expected.getClientId(), computed.getClientId());
+ Assert.assertEquals(expected.getServerId(), computed.getServerId());
+ Assert.assertEquals(expected.getRaftGroupId(), computed.getRaftGroupId());
+ Assert.assertEquals(expected.getCallId(), computed.getCallId());
+ }
+
+ static ByteBuffer initBuffer(int offset, int size) {
+ final ByteBuffer buffer = ByteBuffer.allocateDirect(size);
+ final int length = buffer.capacity();
+ buffer.position(0).limit(length);
+ for (int j = 0; j < length; j++) {
+ buffer.put(pos2byte(offset + j));
+ }
+ buffer.flip();
+ Assert.assertEquals(length, buffer.remaining());
+ return buffer;
+ }
+}
diff --git a/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStreamNetty.java b/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStreamNetty.java
index 1655a0c..e9ce41e 100644
--- a/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStreamNetty.java
+++ b/ratis-test/src/test/java/org/apache/ratis/datastream/TestDataStreamNetty.java
@@ -68,7 +68,7 @@ public class TestDataStreamNetty extends DataStreamBaseTest {
}
@Override
- protected RaftServer newRaftServer(RaftPeer peer, RaftProperties properties) {
+ protected MyRaftServer newRaftServer(RaftPeer peer, RaftProperties properties) {
final RaftProperties p = new RaftProperties(properties);
NettyConfigKeys.DataStream.setPort(p, NetUtils.createSocketAddr(peer.getDataStreamAddress()).getPort());
return super.newRaftServer(peer, p);
@@ -115,7 +115,7 @@ public class TestDataStreamNetty extends DataStreamBaseTest {
if (isLeader) {
final RaftClientReply.Builder b = RaftClientReply.newBuilder().setRequest(r);
reply = leaderException != null? b.setException(leaderException).build()
- : b.setSuccess().setMessage(() -> MOCK).build();
+ : b.setSuccess().setMessage(() -> DataStreamTestUtils.MOCK).build();
} else {
final RaftGroupMemberId memberId = RaftGroupMemberId.valueOf(peerId, groupId);
final NotLeaderException notLeaderException = new NotLeaderException(memberId, suggestedLeader, null);