You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@druid.apache.org by cw...@apache.org on 2020/06/27 06:31:18 UTC

[druid] branch master updated: fix query memory leak (#10027)

This is an automated email from the ASF dual-hosted git repository.

cwylie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git


The following commit(s) were added to refs/heads/master by this push:
     new a4c6d5f  fix query memory leak (#10027)
a4c6d5f is described below

commit a4c6d5f37e88a8152967ba68b150fcf4077092cf
Author: chenyuzhi459 <55...@qq.com>
AuthorDate: Sat Jun 27 14:30:59 2020 +0800

    fix query memory leak (#10027)
    
    * fix query memory leak
    
    * rollup ./idea
    
    * roll up .idea
    
    * clean code
    
    * optimize style
    
    * optimize cancel function
    
    * optimize style
    
    * add concurrentGroupTest test case
    
    * add test case
    
    * add unit test
    
    * fix code style
    
    * optimize cancel method use
    
    * format code
    
    * reback code
    
    * optimize cancelAll
    
    * clean code
    
    * add comment
---
 .../org/apache/druid/common/guava/GuavaUtils.java  |  39 +++
 .../apache/druid/common/guava/GuavaUtilsTest.java  |  59 ++++
 .../druid/query/ChainedExecutionQueryRunner.java   |  19 +-
 .../druid/query/GroupByMergedQueryRunner.java      |  17 +-
 .../groupby/epinephelinae/ConcurrentGrouper.java   |  15 +-
 .../epinephelinae/GroupByMergingQueryRunnerV2.java |  20 +-
 .../groupby/GroupByQueryRunnerFailureTest.java     |  37 +++
 .../client/cache/BackgroundCachePopulator.java     |   2 +
 .../client/cache/BackgroundCachePopulatorTest.java | 304 +++++++++++++++++++++
 9 files changed, 480 insertions(+), 32 deletions(-)

diff --git a/core/src/main/java/org/apache/druid/common/guava/GuavaUtils.java b/core/src/main/java/org/apache/druid/common/guava/GuavaUtils.java
index fa69f82..b345539 100644
--- a/core/src/main/java/org/apache/druid/common/guava/GuavaUtils.java
+++ b/core/src/main/java/org/apache/druid/common/guava/GuavaUtils.java
@@ -22,13 +22,18 @@ package org.apache.druid.common.guava;
 import com.google.common.base.Preconditions;
 import com.google.common.base.Strings;
 import com.google.common.primitives.Longs;
+import org.apache.druid.java.util.common.logger.Logger;
 
 import javax.annotation.Nullable;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.Future;
 
 /**
  */
 public class GuavaUtils
 {
+  private static final Logger log = new Logger(GuavaUtils.class);
 
   /**
    * To fix semantic difference of Longs.tryParse() from Long.parseLong (Longs.tryParse() returns null for '+' started
@@ -77,4 +82,38 @@ public class GuavaUtils
     }
     return arg1;
   }
+
+  /**
+   * Cancels the given futures manually, because cancelling a combined future created by
+   * {@link com.google.common.util.concurrent.Futures#allAsList(Iterable)} does not always
+   * propagate the cancellation to every underlying future.
+   *
+   * @param mayInterruptIfRunning {@code true} if the threads executing the tasks should be
+   *                              interrupted; otherwise, in-progress tasks are allowed to complete
+   * @param combinedFuture        the combined future associated with {@code futures}, or null; it is
+   *                              cancelled before the individual futures
+   * @param futures               the futures to cancel; null elements are skipped
+   */
+  public static <F extends Future<?>> void cancelAll(
+      boolean mayInterruptIfRunning,
+      @Nullable Future<?> combinedFuture,
+      List<F> futures
+  )
+  {
+    final List<Future<?>> allFuturesToCancel = new ArrayList<>(futures.size() + 1);
+    allFuturesToCancel.add(combinedFuture);
+    allFuturesToCancel.addAll(futures);
+    for (Future<?> f : allFuturesToCancel) {
+      if (f == null) {
+        continue;
+      }
+      try {
+        f.cancel(mayInterruptIfRunning);
+      }
+      catch (Throwable t) {
+        // Best-effort: keep cancelling the remaining futures even if one of them throws.
+        log.warn(t, "Error while cancelling future.");
+      }
+    }
+  }
 }
diff --git a/core/src/test/java/org/apache/druid/common/guava/GuavaUtilsTest.java b/core/src/test/java/org/apache/druid/common/guava/GuavaUtilsTest.java
index 6bd764f..27bebbf 100644
--- a/core/src/test/java/org/apache/druid/common/guava/GuavaUtilsTest.java
+++ b/core/src/test/java/org/apache/druid/common/guava/GuavaUtilsTest.java
@@ -20,9 +20,22 @@
 package org.apache.druid.common.guava;
 
 import com.google.common.primitives.Longs;
+import com.google.common.util.concurrent.Futures;
+import com.google.common.util.concurrent.ListenableFuture;
+import com.google.common.util.concurrent.ListeningExecutorService;
+import com.google.common.util.concurrent.MoreExecutors;
 import org.junit.Assert;
 import org.junit.Test;
 
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Function;
+
 public class GuavaUtilsTest
 {
   enum MyEnum
@@ -53,4 +66,58 @@ public class GuavaUtilsTest
     Assert.assertEquals(MyEnum.BUCKLE_MY_SHOE, GuavaUtils.getEnumIfPresent(MyEnum.class, "BUCKLE_MY_SHOE"));
     Assert.assertEquals(null, GuavaUtils.getEnumIfPresent(MyEnum.class, "buckle_my_shoe"));
   }
+
+  @Test
+  public void testCancelAll()
+  {
+    int tasks = 3;
+    ExecutorService service = Executors.newFixedThreadPool(tasks);
+    ListeningExecutorService exc = MoreExecutors.listeningDecorator(service);
+    AtomicInteger index = new AtomicInteger(0);
+    // Flag telling the last-started task when to throw.
+    AtomicBoolean active = new AtomicBoolean(false);
+    Function<Integer, List<ListenableFuture<Object>>> function = (taskCount) -> {
+      List<ListenableFuture<Object>> futures = new ArrayList<>();
+      for (int i = 0; i < taskCount; i++) {
+        ListenableFuture<Object> future = exc.submit(new Callable<Object>() {
+          @Override
+          public Object call() throws RuntimeException
+          {
+            int internalIndex = index.incrementAndGet();
+            // Spin until interrupted so that cancel(true) stops the task instead of leaking a busy thread.
+            while (!Thread.currentThread().isInterrupted()) {
+              if (internalIndex == taskCount && active.get()) {
+                // Simulate a failure in one of the futures.
+                throw new RuntimeException("A big bug");
+              }
+            }
+            return null;
+          }
+        });
+        futures.add(future);
+      }
+      return futures;
+    };
+
+    try {
+      List<ListenableFuture<Object>> futures = function.apply(tasks);
+      Assert.assertEquals(tasks, futures.stream().filter(f -> !f.isDone()).count());
+      // Make one of the tasks throw.
+      active.set(true);
+
+      ListenableFuture<List<Object>> future = Futures.allAsList(futures);
+      try {
+        future.get();
+        Assert.fail("Expected the combined future to fail");
+      }
+      catch (Exception e) {
+        Assert.assertEquals("java.lang.RuntimeException: A big bug", e.getMessage());
+        GuavaUtils.cancelAll(true, future, futures);
+        Assert.assertEquals(0, futures.stream().filter(f -> !f.isDone()).count());
+      }
+    }
+    finally {
+      service.shutdownNow();
+    }
+  }
 }
diff --git a/processing/src/main/java/org/apache/druid/query/ChainedExecutionQueryRunner.java b/processing/src/main/java/org/apache/druid/query/ChainedExecutionQueryRunner.java
index cf14913..f400bac 100644
--- a/processing/src/main/java/org/apache/druid/query/ChainedExecutionQueryRunner.java
+++ b/processing/src/main/java/org/apache/druid/query/ChainedExecutionQueryRunner.java
@@ -27,6 +27,7 @@ import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
 import com.google.common.util.concurrent.ListeningExecutorService;
 import com.google.common.util.concurrent.MoreExecutors;
+import org.apache.druid.common.guava.GuavaUtils;
 import org.apache.druid.java.util.common.ISE;
 import org.apache.druid.java.util.common.guava.BaseSequence;
 import org.apache.druid.java.util.common.guava.MergeIterable;
@@ -100,7 +101,7 @@ public class ChainedExecutionQueryRunner<T> implements QueryRunner<T>
           public Iterator<T> make()
           {
             // Make it a List<> to materialize all of the values (so that it will submit everything to the executor)
-            ListenableFuture<List<Iterable<T>>> futures = Futures.allAsList(
+            List<ListenableFuture<Iterable<T>>> futures =
                 Lists.newArrayList(
                     Iterables.transform(
                         queryables,
@@ -141,22 +142,23 @@ public class ChainedExecutionQueryRunner<T> implements QueryRunner<T>
                           );
                         }
                     )
-                )
-            );
+                );
 
-            queryWatcher.registerQueryFuture(query, futures);
+            ListenableFuture<List<Iterable<T>>> future = Futures.allAsList(futures);
+            queryWatcher.registerQueryFuture(query, future);
 
             try {
               return new MergeIterable<>(
                   ordering.nullsFirst(),
                   QueryContexts.hasTimeout(query) ?
-                      futures.get(QueryContexts.getTimeout(query), TimeUnit.MILLISECONDS) :
-                      futures.get()
+                      future.get(QueryContexts.getTimeout(query), TimeUnit.MILLISECONDS) :
+                      future.get()
               ).iterator();
             }
             catch (InterruptedException e) {
               log.noStackTrace().warn(e, "Query interrupted, cancelling pending results, query id [%s]", query.getId());
-              futures.cancel(true);
+              //Note: canceling combinedFuture first so that it can complete with INTERRUPTED as its final state. See ChainedExecutionQueryRunnerTest.testQueryTimeout()
+              GuavaUtils.cancelAll(true, future, futures);
               throw new QueryInterruptedException(e);
             }
             catch (CancellationException e) {
@@ -164,10 +166,11 @@ public class ChainedExecutionQueryRunner<T> implements QueryRunner<T>
             }
             catch (TimeoutException e) {
               log.warn("Query timeout, cancelling pending results for query id [%s]", query.getId());
-              futures.cancel(true);
+              GuavaUtils.cancelAll(true, future, futures);
               throw new QueryInterruptedException(e);
             }
             catch (ExecutionException e) {
+              GuavaUtils.cancelAll(true, future, futures);
               Throwables.propagateIfPossible(e.getCause());
               throw new RuntimeException(e.getCause());
             }
diff --git a/processing/src/main/java/org/apache/druid/query/GroupByMergedQueryRunner.java b/processing/src/main/java/org/apache/druid/query/GroupByMergedQueryRunner.java
index 1653fb1..8845f2a 100644
--- a/processing/src/main/java/org/apache/druid/query/GroupByMergedQueryRunner.java
+++ b/processing/src/main/java/org/apache/druid/query/GroupByMergedQueryRunner.java
@@ -23,6 +23,7 @@ import com.google.common.base.Function;
 import com.google.common.base.Predicates;
 import com.google.common.base.Supplier;
 import com.google.common.base.Throwables;
+import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Iterables;
 import com.google.common.collect.Lists;
 import com.google.common.util.concurrent.Futures;
@@ -30,6 +31,7 @@ import com.google.common.util.concurrent.ListenableFuture;
 import com.google.common.util.concurrent.ListeningExecutorService;
 import com.google.common.util.concurrent.MoreExecutors;
 import org.apache.druid.collections.NonBlockingPool;
+import org.apache.druid.common.guava.GuavaUtils;
 import org.apache.druid.data.input.Row;
 import org.apache.druid.java.util.common.ISE;
 import org.apache.druid.java.util.common.Pair;
@@ -93,7 +95,7 @@ public class GroupByMergedQueryRunner<T> implements QueryRunner<T>
     final boolean bySegment = QueryContexts.isBySegment(query);
     final int priority = QueryContexts.getPriority(query);
     final QueryPlus<T> threadSafeQueryPlus = queryPlus.withoutThreadUnsafeState();
-    final ListenableFuture<List<Void>> futures = Futures.allAsList(
+    final List<ListenableFuture<Void>> futures =
         Lists.newArrayList(
             Iterables.transform(
                 queryables,
@@ -136,15 +138,14 @@ public class GroupByMergedQueryRunner<T> implements QueryRunner<T>
                     );
 
                     if (isSingleThreaded) {
-                      waitForFutureCompletion(query, future, indexAccumulatorPair.lhs);
+                      waitForFutureCompletion(query, ImmutableList.of(future), indexAccumulatorPair.lhs);
                     }
 
                     return future;
                   }
                 }
             )
-        )
-    );
+        );
 
     if (!isSingleThreaded) {
       waitForFutureCompletion(query, futures, indexAccumulatorPair.lhs);
@@ -173,10 +174,11 @@ public class GroupByMergedQueryRunner<T> implements QueryRunner<T>
 
   private void waitForFutureCompletion(
       GroupByQuery query,
-      ListenableFuture<?> future,
+      List<ListenableFuture<Void>> futures,
       IncrementalIndex<?> closeOnFailure
   )
   {
+    ListenableFuture<List<Void>> future = Futures.allAsList(futures);
     try {
       queryWatcher.registerQueryFuture(query, future);
       if (QueryContexts.hasTimeout(query)) {
@@ -187,7 +189,7 @@ public class GroupByMergedQueryRunner<T> implements QueryRunner<T>
     }
     catch (InterruptedException e) {
       log.warn(e, "Query interrupted, cancelling pending results, query id [%s]", query.getId());
-      future.cancel(true);
+      GuavaUtils.cancelAll(true, future, futures);
       closeOnFailure.close();
       throw new QueryInterruptedException(e);
     }
@@ -198,10 +200,11 @@ public class GroupByMergedQueryRunner<T> implements QueryRunner<T>
     catch (TimeoutException e) {
       closeOnFailure.close();
       log.info("Query timeout, cancelling pending results for query id [%s]", query.getId());
-      future.cancel(true);
+      GuavaUtils.cancelAll(true, future, futures);
       throw new QueryInterruptedException(e);
     }
     catch (ExecutionException e) {
+      GuavaUtils.cancelAll(true, future, futures);
       closeOnFailure.close();
       Throwables.propagateIfPossible(e.getCause());
       throw new RuntimeException(e.getCause());
diff --git a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/ConcurrentGrouper.java b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/ConcurrentGrouper.java
index f5ff7ba..73c1b64 100644
--- a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/ConcurrentGrouper.java
+++ b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/ConcurrentGrouper.java
@@ -28,6 +28,7 @@ import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
 import com.google.common.util.concurrent.ListeningExecutorService;
 import org.apache.druid.collections.ReferenceCountingResourceHolder;
+import org.apache.druid.common.guava.GuavaUtils;
 import org.apache.druid.java.util.common.CloseableIterators;
 import org.apache.druid.java.util.common.ISE;
 import org.apache.druid.java.util.common.parsers.CloseableIterator;
@@ -339,8 +340,7 @@ public class ConcurrentGrouper<KeyType> implements Grouper<KeyType>
   private List<CloseableIterator<Entry<KeyType>>> parallelSortAndGetGroupersIterator()
   {
     // The number of groupers is same with the number of processing threads in the executor
-    final ListenableFuture<List<CloseableIterator<Entry<KeyType>>>> future = Futures.allAsList(
-        groupers.stream()
+    final List<ListenableFuture<CloseableIterator<Entry<KeyType>>>> futures = groupers.stream()
                 .map(grouper ->
                          executor.submit(
                              new AbstractPrioritizedCallable<CloseableIterator<Entry<KeyType>>>(priority)
@@ -353,21 +353,20 @@ public class ConcurrentGrouper<KeyType> implements Grouper<KeyType>
                              }
                          )
                 )
-                .collect(Collectors.toList())
+                .collect(Collectors.toList()
     );
 
+    ListenableFuture<List<CloseableIterator<Entry<KeyType>>>> future = Futures.allAsList(futures);
     try {
       final long timeout = queryTimeoutAt - System.currentTimeMillis();
       return hasQueryTimeout ? future.get(timeout, TimeUnit.MILLISECONDS) : future.get();
     }
-    catch (InterruptedException | TimeoutException e) {
-      future.cancel(true);
-      throw new QueryInterruptedException(e);
-    }
-    catch (CancellationException e) {
+    catch (InterruptedException | TimeoutException | CancellationException e) {
+      GuavaUtils.cancelAll(true, future, futures);
       throw new QueryInterruptedException(e);
     }
     catch (ExecutionException e) {
+      GuavaUtils.cancelAll(true, future, futures);
       throw new RuntimeException(e.getCause());
     }
   }
diff --git a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/GroupByMergingQueryRunnerV2.java b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/GroupByMergingQueryRunnerV2.java
index 2de214c..6d0563c 100644
--- a/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/GroupByMergingQueryRunnerV2.java
+++ b/processing/src/main/java/org/apache/druid/query/groupby/epinephelinae/GroupByMergingQueryRunnerV2.java
@@ -34,6 +34,7 @@ import com.google.common.util.concurrent.MoreExecutors;
 import org.apache.druid.collections.BlockingPool;
 import org.apache.druid.collections.ReferenceCountingResourceHolder;
 import org.apache.druid.collections.Releaser;
+import org.apache.druid.common.guava.GuavaUtils;
 import org.apache.druid.java.util.common.ISE;
 import org.apache.druid.java.util.common.Pair;
 import org.apache.druid.java.util.common.StringUtils;
@@ -215,8 +216,7 @@ public class GroupByMergingQueryRunnerV2 implements QueryRunner<ResultRow>
                   ReferenceCountingResourceHolder.fromCloseable(grouper);
               resources.register(grouperHolder);
 
-              ListenableFuture<List<AggregateResult>> futures = Futures.allAsList(
-                  Lists.newArrayList(
+              List<ListenableFuture<AggregateResult>> futures = Lists.newArrayList(
                       Iterables.transform(
                           queryables,
                           new Function<QueryRunner<ResultRow>, ListenableFuture<AggregateResult>>()
@@ -259,7 +259,7 @@ public class GroupByMergingQueryRunnerV2 implements QueryRunner<ResultRow>
                               if (isSingleThreaded) {
                                 waitForFutureCompletion(
                                     query,
-                                    Futures.allAsList(ImmutableList.of(future)),
+                                    ImmutableList.of(future),
                                     hasTimeout,
                                     timeoutAt - System.currentTimeMillis()
                                 );
@@ -269,8 +269,7 @@ public class GroupByMergingQueryRunnerV2 implements QueryRunner<ResultRow>
                             }
                           }
                       )
-                  )
-              );
+                  );
 
               if (!isSingleThreaded) {
                 waitForFutureCompletion(query, futures, hasTimeout, timeoutAt - System.currentTimeMillis());
@@ -339,11 +338,12 @@ public class GroupByMergingQueryRunnerV2 implements QueryRunner<ResultRow>
 
   private void waitForFutureCompletion(
       GroupByQuery query,
-      ListenableFuture<List<AggregateResult>> future,
+      List<ListenableFuture<AggregateResult>> futures,
       boolean hasTimeout,
       long timeout
   )
   {
+    ListenableFuture<List<AggregateResult>> future = Futures.allAsList(futures);
     try {
       if (queryWatcher != null) {
         queryWatcher.registerQueryFuture(query, future);
@@ -357,25 +357,27 @@ public class GroupByMergingQueryRunnerV2 implements QueryRunner<ResultRow>
 
       for (AggregateResult result : results) {
         if (!result.isOk()) {
-          future.cancel(true);
+          GuavaUtils.cancelAll(true, future, futures);
           throw new ResourceLimitExceededException(result.getReason());
         }
       }
     }
     catch (InterruptedException e) {
       log.warn(e, "Query interrupted, cancelling pending results, query id [%s]", query.getId());
-      future.cancel(true);
+      GuavaUtils.cancelAll(true, future, futures);
       throw new QueryInterruptedException(e);
     }
     catch (CancellationException e) {
+      GuavaUtils.cancelAll(true, future, futures);
       throw new QueryInterruptedException(e);
     }
     catch (TimeoutException e) {
       log.info("Query timeout, cancelling pending results for query id [%s]", query.getId());
-      future.cancel(true);
+      GuavaUtils.cancelAll(true, future, futures);
       throw new QueryInterruptedException(e);
     }
     catch (ExecutionException e) {
+      GuavaUtils.cancelAll(true, future, futures);
       throw new RuntimeException(e);
     }
   }
diff --git a/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryRunnerFailureTest.java b/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryRunnerFailureTest.java
index 142f361..896016c 100644
--- a/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryRunnerFailureTest.java
+++ b/processing/src/test/java/org/apache/druid/query/groupby/GroupByQueryRunnerFailureTest.java
@@ -281,4 +281,41 @@ public class GroupByQueryRunnerFailureTest
       }
     }
   }
+
+  @Test(timeout = 60_000L)
+  public void testTimeoutExceptionOnQueryable()
+  {
+    expectedException.expect(QueryInterruptedException.class);
+    expectedException.expectCause(CoreMatchers.instanceOf(TimeoutException.class));
+    // The 1 ms timeout set below is expected to expire, surfacing as a TimeoutException cause.
+    final GroupByQuery query = GroupByQuery
+        .builder()
+        .setDataSource(QueryRunnerTestHelper.DATA_SOURCE)
+        .setQuerySegmentSpec(QueryRunnerTestHelper.FIRST_TO_THIRD)
+        .setDimensions(new DefaultDimensionSpec("quality", "alias"))
+        .setAggregatorSpecs(new LongSumAggregatorFactory("rows", "rows"))
+        .setGranularity(QueryRunnerTestHelper.DAY_GRAN)
+        .overrideContext(ImmutableMap.of(QueryContexts.TIMEOUT_KEY, 1)) // 1 ms
+        .build();
+    // Single-threaded v2 config: each runner's future is waited on as soon as it is submitted.
+    GroupByQueryRunnerFactory factory = makeQueryRunnerFactory(
+        GroupByQueryRunnerTest.DEFAULT_MAPPER,
+        new GroupByQueryConfig()
+        {
+          @Override
+          public String getDefaultStrategy()
+          {
+            return "v2";
+          }
+
+          @Override
+          public boolean isSingleThreaded()
+          {
+            return true;
+          }
+        }
+    );
+    QueryRunner<ResultRow> mergeRunners = factory.mergeRunners(Execs.directExecutor(), ImmutableList.of(runner));
+    GroupByQueryRunnerTestHelper.runQuery(factory, mergeRunners, query);
+  }
 }
diff --git a/server/src/main/java/org/apache/druid/client/cache/BackgroundCachePopulator.java b/server/src/main/java/org/apache/druid/client/cache/BackgroundCachePopulator.java
index 3b7f30d..3c67cb0 100644
--- a/server/src/main/java/org/apache/druid/client/cache/BackgroundCachePopulator.java
+++ b/server/src/main/java/org/apache/druid/client/cache/BackgroundCachePopulator.java
@@ -27,6 +27,7 @@ import com.google.common.util.concurrent.Futures;
 import com.google.common.util.concurrent.ListenableFuture;
 import com.google.common.util.concurrent.ListeningExecutorService;
 import com.google.common.util.concurrent.MoreExecutors;
+import org.apache.druid.common.guava.GuavaUtils;
 import org.apache.druid.java.util.common.concurrent.Execs;
 import org.apache.druid.java.util.common.guava.Sequence;
 import org.apache.druid.java.util.common.guava.Sequences;
@@ -100,6 +101,7 @@ public class BackgroundCachePopulator implements CachePopulator
                 @Override
                 public void onFailure(Throwable t)
                 {
+                  GuavaUtils.cancelAll(true, null, cacheFutures);
                   log.error(t, "Background caching failed");
                 }
               },
diff --git a/server/src/test/java/org/apache/druid/client/cache/BackgroundCachePopulatorTest.java b/server/src/test/java/org/apache/druid/client/cache/BackgroundCachePopulatorTest.java
new file mode 100644
index 0000000..a93d69e
--- /dev/null
+++ b/server/src/test/java/org/apache/druid/client/cache/BackgroundCachePopulatorTest.java
@@ -0,0 +1,304 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.client.cache;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.common.collect.ImmutableMap;
+import org.apache.druid.client.CacheUtil;
+import org.apache.druid.client.CachingClusteredClientTestUtils;
+import org.apache.druid.java.util.common.DateTimes;
+import org.apache.druid.java.util.common.ISE;
+import org.apache.druid.java.util.common.Intervals;
+import org.apache.druid.java.util.common.concurrent.Execs;
+import org.apache.druid.java.util.common.granularity.Granularities;
+import org.apache.druid.java.util.common.guava.Sequence;
+import org.apache.druid.java.util.common.guava.SequenceWrapper;
+import org.apache.druid.java.util.common.guava.Sequences;
+import org.apache.druid.java.util.emitter.service.ServiceEmitter;
+import org.apache.druid.query.CacheStrategy;
+import org.apache.druid.query.Query;
+import org.apache.druid.query.QueryPlus;
+import org.apache.druid.query.QueryRunner;
+import org.apache.druid.query.QueryToolChest;
+import org.apache.druid.query.Result;
+import org.apache.druid.query.SegmentDescriptor;
+import org.apache.druid.query.aggregation.AggregatorFactory;
+import org.apache.druid.query.aggregation.CountAggregatorFactory;
+import org.apache.druid.query.aggregation.LongSumAggregatorFactory;
+import org.apache.druid.query.context.ResponseContext;
+import org.apache.druid.query.topn.TopNQueryBuilder;
+import org.apache.druid.query.topn.TopNQueryConfig;
+import org.apache.druid.query.topn.TopNQueryQueryToolChest;
+import org.apache.druid.query.topn.TopNResultValue;
+import org.joda.time.DateTime;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.io.Closeable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+public class BackgroundCachePopulatorTest
+{
+  private static final ObjectMapper JSON_MAPPER = CachingClusteredClientTestUtils.createObjectMapper();
+  private static final Object[] OBJECTS = new Object[]{
+      DateTimes.of("2011-01-05"), "a", 50, 4994, "b", 50, 4993, "c", 50, 4992,
+      DateTimes.of("2011-01-06"), "a", 50, 4991, "b", 50, 4990, "c", 50, 4989,
+      DateTimes.of("2011-01-07"), "a", 50, 4991, "b", 50, 4990, "c", 50, 4989,
+      DateTimes.of("2011-01-08"), "a", 50, 4988, "b", 50, 4987, "c", 50, 4986,
+      DateTimes.of("2011-01-09"), "a", 50, 4985, "b", 50, 4984, "c", 50, 4983
+  };
+  private static final List<AggregatorFactory> AGGS = Arrays.asList(
+      new CountAggregatorFactory("rows"),
+      new LongSumAggregatorFactory("imps", "imps"),
+      new LongSumAggregatorFactory("impers", "imps")
+  );
+  private ExecutorService executor;
+  private BackgroundCachePopulator backgroundCachePopulator;
+  private QueryToolChest toolchest;
+  private Cache cache;
+  private Query query;
+  private QueryRunner baseRunner;
+  private AssertingClosable closable;
+
+  @Before
+  public void before()
+  {
+    this.executor = Execs.multiThreaded(2, "BackgroundCachePopulatorTest-%d");
+    this.backgroundCachePopulator = new BackgroundCachePopulator(
+        executor,
+        JSON_MAPPER,
+        new CachePopulatorStats(),
+        -1
+    );
+
+    TopNQueryBuilder builder = new TopNQueryBuilder()
+        .dataSource("ds")
+        .dimension("top_dim")
+        .metric("imps")
+        .threshold(3)
+        .intervals("2011-01-05/2011-01-10")
+        .aggregators(AGGS)
+        .granularity(Granularities.ALL);
+
+    this.query = builder.build();
+    this.toolchest = new TopNQueryQueryToolChest(new TopNQueryConfig());
+    List<Result> expectedRes = makeTopNResults(false, OBJECTS);
+
+    this.closable = new AssertingClosable();
+    final Sequence resultSeq = Sequences.wrap(
+        Sequences.simple(expectedRes),
+        new SequenceWrapper()
+        {
+          @Override
+          public void before()
+          {
+            Assert.assertFalse(closable.isClosed());
+          }
+
+          @Override
+          public void after(boolean isDone, Throwable thrown)
+          {
+            closable.close();
+          }
+        }
+    );
+    this.baseRunner = (queryPlus, responseContext) -> resultSeq;
+
+    this.cache = new Cache()
+    {
+      private final ConcurrentMap<NamedKey, byte[]> baseMap = new ConcurrentHashMap<>();
+
+      @Override
+      public byte[] get(NamedKey key)
+      {
+        return baseMap.get(key);
+      }
+
+      @Override
+      public void put(NamedKey key, byte[] value)
+      {
+        baseMap.put(key, value);
+      }
+
+      @Override
+      public Map<NamedKey, byte[]> getBulk(Iterable<NamedKey> keys)
+      {
+        return null;
+      }
+
+      @Override
+      public void close(String namespace)
+      {
+      }
+
+      @Override
+      public void close()
+      {
+      }
+
+      @Override
+      public CacheStats getStats()
+      {
+        return null;
+      }
+
+      @Override
+      public boolean isLocal()
+      {
+        return true;
+      }
+
+      @Override
+      public void doMonitor(ServiceEmitter emitter)
+      {
+      }
+    };
+  }
+
+  @After
+  public void tearDown()
+  {
+    executor.shutdownNow();
+  }
+
+  /** Checks that wrap() passes every result through unchanged and closes the wrapped sequence. */
+  @Test
+  public void testWrap()
+  {
+    String cacheId = "segment";
+    SegmentDescriptor segmentDescriptor = new SegmentDescriptor(Intervals.of("2011/2012"), "version", 0);
+
+    CacheStrategy cacheStrategy = toolchest.getCacheStrategy(query);
+    Cache.NamedKey cacheKey = CacheUtil.computeSegmentCacheKey(
+        cacheId,
+        segmentDescriptor,
+        cacheStrategy.computeCacheKey(query)
+    );
+
+    Sequence res = this.backgroundCachePopulator.wrap(this.baseRunner.run(QueryPlus.wrap(query), ResponseContext.createEmpty()),
+        (value) -> cacheStrategy.prepareForSegmentLevelCache().apply(value), cache, cacheKey);
+    Assert.assertFalse("sequence must not be closed", closable.isClosed());
+    Assert.assertNull("cache must be empty", cache.get(cacheKey));
+
+    List results = res.toList();
+    Assert.assertTrue(closable.isClosed());
+    List<Result> expectedRes = makeTopNResults(false, OBJECTS);
+    Assert.assertEquals(expectedRes.toString(), results.toString());
+    Assert.assertEquals(5, results.size());
+  }
+
+  @Test
+  public void testWrapOnFailure()
+  {
+    String cacheId = "segment";
+    SegmentDescriptor segmentDescriptor = new SegmentDescriptor(Intervals.of("2011/2012"), "version", 0);
+
+    CacheStrategy cacheStrategy = toolchest.getCacheStrategy(query);
+    Cache.NamedKey cacheKey = CacheUtil.computeSegmentCacheKey(
+        cacheId,
+        segmentDescriptor,
+        cacheStrategy.computeCacheKey(query)
+    );
+
+    Sequence res = this.backgroundCachePopulator.wrap(this.baseRunner.run(QueryPlus.wrap(query), ResponseContext.createEmpty()),
+        (value) -> {
+          throw new RuntimeException("Error"); // make the cache function fail
+        }, cache, cacheKey);
+    Assert.assertFalse("sequence must not be closed", closable.isClosed());
+    Assert.assertNull("cache must be empty", cache.get(cacheKey));
+
+    List results = res.toList();
+    Assert.assertTrue(closable.isClosed());
+    List<Result> expectedRes = makeTopNResults(false, OBJECTS);
+    Assert.assertEquals(expectedRes.toString(), results.toString());
+    Assert.assertEquals(5, results.size());
+  }
+
+  private List<Result> makeTopNResults(boolean cachedResults, Object... objects)
+  {
+    List<Result> retVal = new ArrayList<>();
+    int index = 0;
+    while (index < objects.length) {
+      DateTime timestamp = (DateTime) objects[index++];
+
+      List<Map<String, Object>> values = new ArrayList<>();
+      while (index < objects.length && !(objects[index] instanceof DateTime)) {
+        if (objects.length - index < 3) {
+          throw new ISE(
+              "expect 3 values for each entry in the top list, had %d values left.", objects.length - index
+          );
+        }
+        final double imps = ((Number) objects[index + 2]).doubleValue();
+        final double rows = ((Number) objects[index + 1]).doubleValue();
+
+        if (cachedResults) {
+          values.add(
+              ImmutableMap.of(
+                  "top_dim", objects[index],
+                  "rows", rows,
+                  "imps", imps,
+                  "impers", imps
+              )
+          );
+        } else {
+          values.add(
+              ImmutableMap.of(
+                  "top_dim", objects[index],
+                  "rows", rows,
+                  "imps", imps,
+                  "impers", imps,
+                  "avg_imps_per_row", imps / rows
+              )
+          );
+        }
+        index += 3;
+      }
+
+      retVal.add(new Result<>(timestamp, new TopNResultValue(values)));
+    }
+    return retVal;
+  }
+
+  private static class AssertingClosable implements Closeable
+  {
+    private final AtomicBoolean closed = new AtomicBoolean(false);
+
+    @Override
+    public void close()
+    {
+      Assert.assertFalse(closed.get());
+      Assert.assertTrue(closed.compareAndSet(false, true));
+    }
+
+    public boolean isClosed()
+    {
+      return closed.get();
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@druid.apache.org
For additional commands, e-mail: commits-help@druid.apache.org