You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@iotdb.apache.org by ja...@apache.org on 2022/08/17 13:26:12 UTC

[iotdb] 01/01: [IOTDB-4173] Fix NPE in SourceHandle

This is an automated email from the ASF dual-hosted git repository.

jackietien pushed a commit to branch IOTDB-4173
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit b7f8993519f2965cac03beef113cf632ad82658c
Author: JackieTien97 <ja...@gmail.com>
AuthorDate: Wed Aug 17 21:25:56 2022 +0800

    [IOTDB-4173] Fix NPE in SourceHandle
---
 .../mpp/execution/exchange/SharedTsBlockQueue.java |  3 +-
 .../db/mpp/execution/exchange/SinkHandle.java      | 15 ++++------
 .../db/mpp/execution/exchange/SourceHandle.java    | 34 ++++++++++++++++------
 .../iotdb/db/mpp/execution/memory/MemoryPool.java  |  9 ++++--
 .../mpp/execution/exchange/SourceHandleTest.java   | 12 ++++----
 .../iotdb/db/mpp/execution/exchange/Utils.java     |  3 +-
 .../db/mpp/execution/memory/MemoryPoolTest.java    | 24 +++++++--------
 7 files changed, 60 insertions(+), 40 deletions(-)

diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SharedTsBlockQueue.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SharedTsBlockQueue.java
index 6a594b4bf6..aacffe4796 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SharedTsBlockQueue.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SharedTsBlockQueue.java
@@ -143,7 +143,8 @@ public class SharedTsBlockQueue {
     blockedOnMemory =
         localMemoryManager
             .getQueryPool()
-            .reserve(localFragmentInstanceId.getQueryId(), tsBlock.getRetainedSizeInBytes());
+            .reserve(localFragmentInstanceId.getQueryId(), tsBlock.getRetainedSizeInBytes())
+            .left;
     bufferRetainedSizeInBytes += tsBlock.getRetainedSizeInBytes();
     queue.add(tsBlock);
     if (!blocked.isDone()) {
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SinkHandle.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SinkHandle.java
index 6ccddaf44e..644c9985bf 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SinkHandle.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SinkHandle.java
@@ -32,6 +32,7 @@ import org.apache.iotdb.tsfile.read.common.block.TsBlock;
 import org.apache.iotdb.tsfile.read.common.block.column.TsBlockSerde;
 import org.apache.iotdb.tsfile.utils.Pair;
 
+import com.google.common.collect.ImmutableList;
 import com.google.common.util.concurrent.ListenableFuture;
 import io.airlift.concurrent.SetThreadName;
 import org.apache.commons.lang3.Validate;
@@ -40,7 +41,6 @@ import org.slf4j.LoggerFactory;
 
 import java.io.IOException;
 import java.nio.ByteBuffer;
-import java.util.ArrayList;
 import java.util.Iterator;
 import java.util.LinkedHashMap;
 import java.util.List;
@@ -117,7 +117,8 @@ public class SinkHandle implements ISinkHandle {
     this.blocked =
         localMemoryManager
             .getQueryPool()
-            .reserve(localFragmentInstanceId.getQueryId(), DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES);
+            .reserve(localFragmentInstanceId.getQueryId(), DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES)
+            .left;
     this.bufferRetainedSizeInBytes = DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES;
     this.currentTsBlockSize = DEFAULT_MAX_TSBLOCK_SIZE_IN_BYTES;
   }
@@ -144,24 +145,20 @@ public class SinkHandle implements ISinkHandle {
     }
     long retainedSizeInBytes = tsBlock.getRetainedSizeInBytes();
     int startSequenceId;
-    List<Long> tsBlockSizes = new ArrayList<>();
     startSequenceId = nextSequenceId;
     blocked =
         localMemoryManager
             .getQueryPool()
-            .reserve(localFragmentInstanceId.getQueryId(), retainedSizeInBytes);
+            .reserve(localFragmentInstanceId.getQueryId(), retainedSizeInBytes)
+            .left;
     bufferRetainedSizeInBytes += retainedSizeInBytes;
 
     sequenceIdToTsBlock.put(nextSequenceId, new Pair<>(tsBlock, currentTsBlockSize));
     nextSequenceId += 1;
     currentTsBlockSize = retainedSizeInBytes;
 
-    for (int i = startSequenceId; i < nextSequenceId; i++) {
-      tsBlockSizes.add(sequenceIdToTsBlock.get(i).left.getRetainedSizeInBytes());
-    }
-
     // TODO: consider merge multiple NewDataBlockEvent for less network traffic.
-    submitSendNewDataBlockEventTask(startSequenceId, tsBlockSizes);
+    submitSendNewDataBlockEventTask(startSequenceId, ImmutableList.of(retainedSizeInBytes));
   }
 
   @Override
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SourceHandle.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SourceHandle.java
index 1dc68c27a3..1e4fa305e4 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SourceHandle.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/exchange/SourceHandle.java
@@ -31,6 +31,7 @@ import org.apache.iotdb.mpp.rpc.thrift.TGetDataBlockRequest;
 import org.apache.iotdb.mpp.rpc.thrift.TGetDataBlockResponse;
 import org.apache.iotdb.tsfile.read.common.block.TsBlock;
 import org.apache.iotdb.tsfile.read.common.block.column.TsBlockSerde;
+import org.apache.iotdb.tsfile.utils.Pair;
 
 import com.google.common.util.concurrent.ListenableFuture;
 import com.google.common.util.concurrent.SettableFuture;
@@ -160,33 +161,35 @@ public class SourceHandle implements ISourceHandle {
       final int startSequenceId = nextSequenceId;
       int endSequenceId = nextSequenceId;
       long reservedBytes = 0L;
-      ListenableFuture<Void> future = null;
+      Pair<ListenableFuture<Void>, Boolean> pair = null;
+      long blockedSize = 0L;
       while (sequenceIdToDataBlockSize.containsKey(endSequenceId)) {
         Long bytesToReserve = sequenceIdToDataBlockSize.get(endSequenceId);
         if (bytesToReserve == null) {
           throw new IllegalStateException("Data block size is null.");
         }
-        future =
+        pair =
             localMemoryManager
                 .getQueryPool()
                 .reserve(localFragmentInstanceId.getQueryId(), bytesToReserve);
         bufferRetainedSizeInBytes += bytesToReserve;
         endSequenceId += 1;
         reservedBytes += bytesToReserve;
-        if (!future.isDone()) {
+        if (!pair.right) {
+          blockedSize = bytesToReserve;
           break;
         }
       }
 
-      if (future == null) {
+      if (pair == null) {
         // Next data block not generated yet. Do nothing.
         return;
       }
-
       nextSequenceId = endSequenceId;
-      executorService.submit(new GetDataBlocksTask(startSequenceId, endSequenceId, reservedBytes));
-      if (!future.isDone()) {
-        blockedOnMemory = future;
+
+      if (!pair.right) {
+        endSequenceId--;
+        reservedBytes -= blockedSize;
         // The future being not completed indicates,
         //   1. Memory has been reserved for blocks in [startSequenceId, endSequenceId).
         //   2. Memory reservation for block whose sequence ID equals endSequenceId - 1 is blocked.
@@ -196,7 +199,20 @@ public class SourceHandle implements ISourceHandle {
         //         |-------- reserved --------|--- blocked ---|--- not reserved ---|
 
         // Schedule another call of trySubmitGetDataBlocksTask for the rest of blocks.
-        future.addListener(SourceHandle.this::trySubmitGetDataBlocksTask, executorService);
+        blockedOnMemory = pair.left;
+        final int blockedSequenceId = endSequenceId;
+        final long blockedRetainedSize = blockedSize;
+        blockedOnMemory.addListener(
+            () ->
+                executorService.submit(
+                    new GetDataBlocksTask(
+                        blockedSequenceId, blockedSequenceId + 1, blockedRetainedSize)),
+            executorService);
+      }
+
+      if (endSequenceId > startSequenceId) {
+        executorService.submit(
+            new GetDataBlocksTask(startSequenceId, endSequenceId, reservedBytes));
       }
     }
   }
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPool.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPool.java
index 7b24d32c13..08a5482138 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPool.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPool.java
@@ -19,6 +19,8 @@
 
 package org.apache.iotdb.db.mpp.execution.memory;
 
+import org.apache.iotdb.tsfile.utils.Pair;
+
 import com.google.common.util.concurrent.AbstractFuture;
 import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
@@ -89,7 +91,8 @@ public class MemoryPool {
     return maxBytes;
   }
 
-  public ListenableFuture<Void> reserve(String queryId, long bytes) {
+  /** @return left: reservation future; right: true if reserved immediately, false if queued */
+  public Pair<ListenableFuture<Void>, Boolean> reserve(String queryId, long bytes) {
     Validate.notNull(queryId);
     Validate.isTrue(
         bytes > 0L && bytes <= maxBytesPerQuery,
@@ -101,14 +104,14 @@ public class MemoryPool {
           || maxBytesPerQuery - queryMemoryReservations.getOrDefault(queryId, 0L) < bytes) {
         result = MemoryReservationFuture.create(queryId, bytes);
         memoryReservationFutures.add((MemoryReservationFuture<Void>) result);
+        return new Pair<>(result, Boolean.FALSE);
       } else {
         reservedBytes += bytes;
         queryMemoryReservations.merge(queryId, bytes, Long::sum);
         result = Futures.immediateFuture(null);
+        return new Pair<>(result, Boolean.TRUE);
       }
     }
-
-    return result;
   }
 
   public boolean tryReserve(String queryId, long bytes) {
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/SourceHandleTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/SourceHandleTest.java
index 83e97a98a3..9aaec0a940 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/SourceHandleTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/SourceHandleTest.java
@@ -237,14 +237,14 @@ public class SourceHandleTest {
                   req ->
                       remoteFragmentInstanceId.equals(req.getSourceFragmentInstanceId())
                           && 0 == req.getStartSequenceId()
-                          && 6 == req.getEndSequenceId()));
+                          && 5 == req.getEndSequenceId()));
       Mockito.verify(mockClient, Mockito.times(1))
           .onAcknowledgeDataBlockEvent(
               Mockito.argThat(
                   e ->
                       remoteFragmentInstanceId.equals(e.getSourceFragmentInstanceId())
                           && 0 == e.getStartSequenceId()
-                          && 6 == e.getEndSequenceId()));
+                          && 5 == e.getEndSequenceId()));
     } catch (InterruptedException | TException e) {
       e.printStackTrace();
       Assert.fail();
@@ -260,9 +260,11 @@ public class SourceHandleTest {
       sourceHandle.receive();
       try {
         Thread.sleep(100L);
-        if (i < 4) {
-          Assert.assertEquals(6 * mockTsBlockSize, sourceHandle.getBufferRetainedSizeInBytes());
-          final int startSequenceId = 6 + i;
+        if (i < 5) {
+          Assert.assertEquals(
+              i == 4 ? 5 * mockTsBlockSize : 6 * mockTsBlockSize,
+              sourceHandle.getBufferRetainedSizeInBytes());
+          final int startSequenceId = 5 + i;
           Mockito.verify(mockClient, Mockito.times(1))
               .getDataBlock(
                   Mockito.argThat(
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/Utils.java b/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/Utils.java
index 165f0e65a9..c72b2c26d4 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/Utils.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/Utils.java
@@ -22,6 +22,7 @@ package org.apache.iotdb.db.mpp.execution.exchange;
 import org.apache.iotdb.db.mpp.execution.memory.MemoryPool;
 import org.apache.iotdb.tsfile.read.common.block.TsBlock;
 import org.apache.iotdb.tsfile.read.common.block.column.TsBlockSerde;
+import org.apache.iotdb.tsfile.utils.Pair;
 
 import com.google.common.util.concurrent.SettableFuture;
 import org.mockito.Mockito;
@@ -100,7 +101,7 @@ public class Utils {
   public static MemoryPool createMockNonBlockedMemoryPool() {
     MemoryPool mockMemoryPool = Mockito.mock(MemoryPool.class);
     Mockito.when(mockMemoryPool.reserve(Mockito.anyString(), Mockito.anyLong()))
-        .thenReturn(immediateFuture(null));
+        .thenReturn(new Pair<>(immediateFuture(null), true));
     Mockito.when(mockMemoryPool.tryReserve(Mockito.anyString(), Mockito.anyLong()))
         .thenReturn(true);
     return mockMemoryPool;
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPoolTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPoolTest.java
index 7fa8def8f8..3d77d651b9 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPoolTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPoolTest.java
@@ -83,7 +83,7 @@ public class MemoryPoolTest {
   @Test
   public void testReserve() {
     String queryId = "q0";
-    ListenableFuture<Void> future = pool.reserve(queryId, 256L);
+    ListenableFuture<Void> future = pool.reserve(queryId, 256L).left;
     Assert.assertTrue(future.isDone());
     Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(queryId));
     Assert.assertEquals(256L, pool.getReservedBytes());
@@ -112,7 +112,7 @@ public class MemoryPoolTest {
   @Test
   public void testReserveAll() {
     String queryId = "q0";
-    ListenableFuture<Void> future = pool.reserve(queryId, 512L);
+    ListenableFuture<Void> future = pool.reserve(queryId, 512L).left;
     Assert.assertTrue(future.isDone());
     Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(queryId));
     Assert.assertEquals(512L, pool.getReservedBytes());
@@ -121,11 +121,11 @@ public class MemoryPoolTest {
   @Test
   public void testOverReserve() {
     String queryId = "q0";
-    ListenableFuture<Void> future = pool.reserve(queryId, 256L);
+    ListenableFuture<Void> future = pool.reserve(queryId, 256L).left;
     Assert.assertTrue(future.isDone());
     Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(queryId));
     Assert.assertEquals(256L, pool.getReservedBytes());
-    future = pool.reserve(queryId, 512L);
+    future = pool.reserve(queryId, 512L).left;
     Assert.assertFalse(future.isDone());
     Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(queryId));
     Assert.assertEquals(256L, pool.getReservedBytes());
@@ -134,8 +134,8 @@ public class MemoryPoolTest {
   @Test
   public void testReserveAndFree() {
     String queryId = "q0";
-    Assert.assertTrue(pool.reserve(queryId, 512L).isDone());
-    ListenableFuture<Void> future = pool.reserve(queryId, 512L);
+    Assert.assertTrue(pool.reserve(queryId, 512L).left.isDone());
+    ListenableFuture<Void> future = pool.reserve(queryId, 512L).left;
     Assert.assertFalse(future.isDone());
     Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(queryId));
     Assert.assertEquals(512L, pool.getReservedBytes());
@@ -148,13 +148,13 @@ public class MemoryPoolTest {
   @Test
   public void testMultiReserveAndFree() {
     String queryId = "q0";
-    Assert.assertTrue(pool.reserve(queryId, 256L).isDone());
+    Assert.assertTrue(pool.reserve(queryId, 256L).left.isDone());
     Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(queryId));
     Assert.assertEquals(256L, pool.getReservedBytes());
 
-    ListenableFuture<Void> future1 = pool.reserve(queryId, 512L);
-    ListenableFuture<Void> future2 = pool.reserve(queryId, 512L);
-    ListenableFuture<Void> future3 = pool.reserve(queryId, 512L);
+    ListenableFuture<Void> future1 = pool.reserve(queryId, 512L).left;
+    ListenableFuture<Void> future2 = pool.reserve(queryId, 512L).left;
+    ListenableFuture<Void> future3 = pool.reserve(queryId, 512L).left;
     Assert.assertFalse(future1.isDone());
     Assert.assertFalse(future2.isDone());
     Assert.assertFalse(future3.isDone());
@@ -254,7 +254,7 @@ public class MemoryPoolTest {
     // Run out of memory.
     Assert.assertTrue(pool.tryReserve(queryId, 512L));
 
-    ListenableFuture<Void> f = pool.reserve(queryId, 256L);
+    ListenableFuture<Void> f = pool.reserve(queryId, 256L).left;
     Assert.assertFalse(f.isDone());
     // Cancel the reservation.
     Assert.assertEquals(256L, pool.tryCancel(f));
@@ -265,7 +265,7 @@ public class MemoryPoolTest {
   @Test
   public void testTryCancelCompletedReservation() {
     String queryId = "q0";
-    ListenableFuture<Void> f = pool.reserve(queryId, 256L);
+    ListenableFuture<Void> f = pool.reserve(queryId, 256L).left;
     Assert.assertTrue(f.isDone());
     // Cancel the reservation.
     Assert.assertEquals(0L, pool.tryCancel(f));