Posted to commits@druid.apache.org by ji...@apache.org on 2021/08/06 21:07:35 UTC

[druid] branch master updated: Improve concurrency between DruidSchema and BrokerServerView (#11457)

This is an automated email from the ASF dual-hosted git repository.

jihoonson pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git


The following commit(s) were added to refs/heads/master by this push:
     new e9d964d  Improve concurrency between DruidSchema and BrokerServerView (#11457)
e9d964d is described below

commit e9d964d504cb510226d58b2caa299cecfec99e15
Author: Jihoon Son <ji...@apache.org>
AuthorDate: Fri Aug 6 14:07:13 2021 -0700

    Improve concurrency between DruidSchema and BrokerServerView (#11457)
    
    * Improve concurrency between DruidSchema and BrokerServerView
    
    * unused imports and workaround for error-prone failure
    
    * count only known segments
    
    * add comments
---
 .../druid/client/SingleServerInventoryView.java    |   4 +-
 .../druid/sql/calcite/schema/DruidSchema.java      | 623 ++++++++++++++-------
 .../calcite/schema/DruidSchemaConcurrencyTest.java | 485 ++++++++++++++++
 .../druid/sql/calcite/schema/DruidSchemaTest.java  | 611 +++++++++++++++-----
 .../sql/calcite/schema/DruidSchemaTestCommon.java  | 137 +++++
 5 files changed, 1515 insertions(+), 345 deletions(-)
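The central change is in DruidSchema.java: the per-datasource segment metadata map becomes a ConcurrentHashMap of ConcurrentSkipListMaps, and writers update it with nested compute() calls while holding the coarse lock, as described in the new class javadoc below. The following is a minimal, self-contained sketch of that pattern; the class name, field names, and value type are placeholders for illustration, not the actual Druid code.

import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentSkipListMap;

public class NestedComputeSketch
{
  // Coarse lock guarding bookkeeping that lives outside the concurrent maps.
  private final Object lock = new Object();

  // dataSource -> (segmentId -> numRows); a placeholder value type stands in for AvailableSegmentMetadata.
  private final ConcurrentHashMap<String, ConcurrentSkipListMap<String, Long>> segmentsByDataSource =
      new ConcurrentHashMap<>();

  public void updateSegment(String dataSource, String segmentId, long numRows)
  {
    // Acquire the coarse lock before compute() so the function run inside compute() stays cheap
    // and never blocks on the lock while holding a map entry.
    synchronized (lock) {
      segmentsByDataSource.compute(dataSource, (ds, segments) -> {
        if (segments == null) {
          segments = new ConcurrentSkipListMap<>();
        }
        // The inner compute() locks only the per-segment entry.
        segments.compute(segmentId, (id, oldRows) -> numRows);
        return segments;
      });
    }
  }
}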

diff --git a/server/src/main/java/org/apache/druid/client/SingleServerInventoryView.java b/server/src/main/java/org/apache/druid/client/SingleServerInventoryView.java
index a7a4630..0b59137 100644
--- a/server/src/main/java/org/apache/druid/client/SingleServerInventoryView.java
+++ b/server/src/main/java/org/apache/druid/client/SingleServerInventoryView.java
@@ -120,13 +120,13 @@ public class SingleServerInventoryView extends AbstractCuratorServerInventoryVie
     segmentPredicates.remove(callback);
   }
 
-  static class FilteringSegmentCallback implements SegmentCallback
+  public static class FilteringSegmentCallback implements SegmentCallback
   {
 
     private final SegmentCallback callback;
     private final Predicate<Pair<DruidServerMetadata, DataSegment>> filter;
 
-    FilteringSegmentCallback(SegmentCallback callback, Predicate<Pair<DruidServerMetadata, DataSegment>> filter)
+    public FilteringSegmentCallback(SegmentCallback callback, Predicate<Pair<DruidServerMetadata, DataSegment>> filter)
     {
       this.callback = callback;
       this.filter = filter;
diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/schema/DruidSchema.java b/sql/src/main/java/org/apache/druid/sql/calcite/schema/DruidSchema.java
index 30a1576..e5846a9 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/schema/DruidSchema.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/schema/DruidSchema.java
@@ -70,7 +70,6 @@ import org.apache.druid.timeline.SegmentId;
 import java.io.IOException;
 import java.util.Comparator;
 import java.util.EnumSet;
-import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Map;
 import java.util.Optional;
@@ -79,6 +78,7 @@ import java.util.TreeMap;
 import java.util.TreeSet;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.ConcurrentSkipListMap;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutorService;
 import java.util.function.Function;
@@ -100,26 +100,82 @@ public class DruidSchema extends AbstractSchema
 
   private final QueryLifecycleFactory queryLifecycleFactory;
   private final PlannerConfig config;
+  // Escalator, so we can attach an authentication result to queries we generate.
+  private final Escalator escalator;
   private final SegmentManager segmentManager;
   private final JoinableFactory joinableFactory;
   private final ExecutorService cacheExec;
-  private final ConcurrentMap<String, DruidTable> tables;
+  private final ExecutorService callbackExec;
+
+  /**
+   * Map of DataSource -> DruidTable.
+   * This map can be accessed by {@link #cacheExec} and {@link #callbackExec} threads.
+   */
+  private final ConcurrentMap<String, DruidTable> tables = new ConcurrentHashMap<>();
+
+  /**
+   * DataSource -> Segment -> AvailableSegmentMetadata(contains RowSignature) for that segment.
+   * Use SortedMap for segments so they are merged in deterministic order, from older to newer.
+   *
+   * This map is updated by these two threads.
+   *
+   * - {@link #callbackExec} can update it in {@link #addSegment}, {@link #removeServerSegment},
+   *   and {@link #removeSegment}.
+   * - {@link #cacheExec} can update it in {@link #refreshSegmentsForDataSource}.
+   *
+   * While it is being updated, this map is read by these two types of thread.
+   *
+   * - {@link #cacheExec} can iterate all {@link AvailableSegmentMetadata}s per datasource.
+   *   See {@link #buildDruidTable}.
+   * - Query threads can create a snapshot of the entire map for processing queries on the system table.
+   *   See {@link #getSegmentMetadataSnapshot()}.
+   *
+   * As the access pattern of this map is read-intensive, we should minimize contention between writers and readers.
+   * Since two threads can update this map at the same time, writers should lock the inner map for a datasource
+   * first and then lock the per-segment entry before updating segment metadata. This can be done using
+   * {@link ConcurrentMap#compute} as below. Note that if you need to update the variables guarded by {@link #lock}
+   * inside of compute(), you should acquire the lock before calling compute() so that the function executed in
+   * compute() stays cheap.
+   *
+   * <pre>
+   *   segmentMetadataInfo.compute(
+   *     datasourceParam,
+   *     (datasource, segmentsMap) -> {
+   *       if (segmentsMap == null) return null;
+   *       else {
+   *         segmentsMap.compute(
+   *           segmentIdParam,
+   *           (segmentId, segmentMetadata) -> {
+   *             // update segmentMetadata
+   *           }
+   *         );
+   *         return segmentsMap;
+   *       }
+   *     }
+   *   );
+   * </pre>
+   *
+   * Readers can simply delegate the locking to the concurrent map and iterate map entries.
+   */
+  private final ConcurrentHashMap<String, ConcurrentSkipListMap<SegmentId, AvailableSegmentMetadata>> segmentMetadataInfo
+      = new ConcurrentHashMap<>();
 
   // For awaitInitialization.
   private final CountDownLatch initialized = new CountDownLatch(1);
 
-  // Protects access to segmentSignatures, mutableSegments, segmentsNeedingRefresh, lastRefresh, isServerViewInitialized, segmentMetadata
+  /**
+   * This lock coordinates access by multiple threads to the variables it guards.
+   * Currently, two threads can access these variables.
+   *
+   * - {@link #callbackExec} executes the timeline callbacks whenever BrokerServerView changes.
+   * - {@link #cacheExec} periodically refreshes segment metadata and {@link DruidTable} if necessary
+   *   based on the information collected via timeline callbacks.
+   */
   private final Object lock = new Object();
 
-  // DataSource -> Segment -> AvailableSegmentMetadata(contains RowSignature) for that segment.
-  // Use TreeMap for segments so they are merged in deterministic order, from older to newer.
-  @GuardedBy("lock")
-  private final Map<String, TreeMap<SegmentId, AvailableSegmentMetadata>> segmentMetadataInfo = new HashMap<>();
-  private int totalSegments = 0;
-
   // All mutable segments.
   @GuardedBy("lock")
-  private final Set<SegmentId> mutableSegments = new TreeSet<>(SEGMENT_ORDER);
+  private final TreeSet<SegmentId> mutableSegments = new TreeSet<>(SEGMENT_ORDER);
 
   // All dataSources that need tables regenerated.
   @GuardedBy("lock")
@@ -129,18 +185,19 @@ public class DruidSchema extends AbstractSchema
   @GuardedBy("lock")
   private final TreeSet<SegmentId> segmentsNeedingRefresh = new TreeSet<>(SEGMENT_ORDER);
 
-  // Escalator, so we can attach an authentication result to queries we generate.
-  private final Escalator escalator;
-
   @GuardedBy("lock")
   private boolean refreshImmediately = false;
-  @GuardedBy("lock")
-  private long lastRefresh = 0L;
-  @GuardedBy("lock")
-  private long lastFailure = 0L;
+
   @GuardedBy("lock")
   private boolean isServerViewInitialized = false;
 
+  /**
+   * Counts the total number of known segments. This variable is used only by the segments table in the system schema
+   * to size the map it creates when taking a snapshot, so the count does not have to be exact and there is no
+   * concurrency control for this variable.
+   */
+  private int totalSegments = 0;
+
   @Inject
   public DruidSchema(
       final QueryLifecycleFactory queryLifecycleFactory,
@@ -157,11 +214,11 @@ public class DruidSchema extends AbstractSchema
     this.joinableFactory = joinableFactory;
     this.config = Preconditions.checkNotNull(config, "config");
     this.cacheExec = Execs.singleThreaded("DruidSchema-Cache-%d");
-    this.tables = new ConcurrentHashMap<>();
+    this.callbackExec = Execs.singleThreaded("DruidSchema-Callback-%d");
     this.escalator = escalator;
 
     serverView.registerTimelineCallback(
-        Execs.directExecutor(),
+        callbackExec,
         new TimelineServerView.TimelineCallback()
         {
           @Override
@@ -207,6 +264,9 @@ public class DruidSchema extends AbstractSchema
   {
     cacheExec.submit(
         () -> {
+          long lastRefresh = 0L;
+          long lastFailure = 0L;
+
           try {
             while (!Thread.currentThread().isInterrupted()) {
               final Set<SegmentId> segmentsToRefresh = new TreeSet<>();
@@ -259,32 +319,7 @@ public class DruidSchema extends AbstractSchema
                   refreshImmediately = false;
                 }
 
-                // Refresh the segments.
-                final Set<SegmentId> refreshed = refreshSegments(segmentsToRefresh);
-
-                synchronized (lock) {
-                  // Add missing segments back to the refresh list.
-                  segmentsNeedingRefresh.addAll(Sets.difference(segmentsToRefresh, refreshed));
-
-                  // Compute the list of dataSources to rebuild tables for.
-                  dataSourcesToRebuild.addAll(dataSourcesNeedingRebuild);
-                  refreshed.forEach(segment -> dataSourcesToRebuild.add(segment.getDataSource()));
-                  dataSourcesNeedingRebuild.clear();
-
-                  lock.notifyAll();
-                }
-
-                // Rebuild the dataSources.
-                for (String dataSource : dataSourcesToRebuild) {
-                  final DruidTable druidTable = buildDruidTable(dataSource);
-                  final DruidTable oldTable = tables.put(dataSource, druidTable);
-                  final String description = druidTable.getDataSource().isGlobal() ? "global dataSource" : "dataSource";
-                  if (oldTable == null || !oldTable.getRowSignature().equals(druidTable.getRowSignature())) {
-                    log.info("%s [%s] has new signature: %s.", description, dataSource, druidTable.getRowSignature());
-                  } else {
-                    log.debug("%s [%s] signature is unchanged.", description, dataSource);
-                  }
-                }
+                refresh(segmentsToRefresh, dataSourcesToRebuild);
 
                 initialized.countDown();
               }
@@ -300,7 +335,6 @@ public class DruidSchema extends AbstractSchema
                   segmentsNeedingRefresh.addAll(segmentsToRefresh);
                   dataSourcesNeedingRebuild.addAll(dataSourcesToRebuild);
                   lastFailure = System.currentTimeMillis();
-                  lock.notifyAll();
                 }
               }
             }
@@ -328,10 +362,40 @@ public class DruidSchema extends AbstractSchema
     }
   }
 
+  @VisibleForTesting
+  void refresh(final Set<SegmentId> segmentsToRefresh, final Set<String> dataSourcesToRebuild) throws IOException
+  {
+    // Refresh the segments.
+    final Set<SegmentId> refreshed = refreshSegments(segmentsToRefresh);
+
+    synchronized (lock) {
+      // Add missing segments back to the refresh list.
+      segmentsNeedingRefresh.addAll(Sets.difference(segmentsToRefresh, refreshed));
+
+      // Compute the list of dataSources to rebuild tables for.
+      dataSourcesToRebuild.addAll(dataSourcesNeedingRebuild);
+      refreshed.forEach(segment -> dataSourcesToRebuild.add(segment.getDataSource()));
+      dataSourcesNeedingRebuild.clear();
+    }
+
+    // Rebuild the dataSources.
+    for (String dataSource : dataSourcesToRebuild) {
+      final DruidTable druidTable = buildDruidTable(dataSource);
+      final DruidTable oldTable = tables.put(dataSource, druidTable);
+      final String description = druidTable.getDataSource().isGlobal() ? "global dataSource" : "dataSource";
+      if (oldTable == null || !oldTable.getRowSignature().equals(druidTable.getRowSignature())) {
+        log.info("%s [%s] has new signature: %s.", description, dataSource, druidTable.getRowSignature());
+      } else {
+        log.debug("%s [%s] signature is unchanged.", description, dataSource);
+      }
+    }
+  }
+
   @LifecycleStop
   public void stop()
   {
     cacheExec.shutdownNow();
+    callbackExec.shutdownNow();
   }
 
   public void awaitInitialization() throws InterruptedException
@@ -348,54 +412,66 @@ public class DruidSchema extends AbstractSchema
   @VisibleForTesting
   void addSegment(final DruidServerMetadata server, final DataSegment segment)
   {
+    // Get the lock first so that we won't wait inside ConcurrentMap.compute().
     synchronized (lock) {
       // someday we could hypothetically remove broker special casing, whenever BrokerServerView supports tracking
       // broker served segments in the timeline, to ensure that removeSegment the event is triggered accurately
       if (server.getType().equals(ServerType.BROKER)) {
         // a segment on a broker means a broadcast datasource, skip metadata because we'll also see this segment on the
         // historical, however mark the datasource for refresh because it needs to be globalized
-        dataSourcesNeedingRebuild.add(segment.getDataSource());
+        markDataSourceAsNeedRebuild(segment.getDataSource());
       } else {
-        final Map<SegmentId, AvailableSegmentMetadata> knownSegments = segmentMetadataInfo.get(segment.getDataSource());
-        AvailableSegmentMetadata segmentMetadata = knownSegments != null ? knownSegments.get(segment.getId()) : null;
-        if (segmentMetadata == null) {
-          // segmentReplicatable is used to determine if segments are served by historical or realtime servers
-          long isRealtime = server.isSegmentReplicationTarget() ? 0 : 1;
-          segmentMetadata = AvailableSegmentMetadata.builder(
-              segment,
-              isRealtime,
-              ImmutableSet.of(server),
-              null,
-              DEFAULT_NUM_ROWS
-          ).build();
-          // Unknown segment.
-          setAvailableSegmentMetadata(segment.getId(), segmentMetadata);
-          segmentsNeedingRefresh.add(segment.getId());
-          if (!server.isSegmentReplicationTarget()) {
-            log.debug("Added new mutable segment[%s].", segment.getId());
-            mutableSegments.add(segment.getId());
-          } else {
-            log.debug("Added new immutable segment[%s].", segment.getId());
-          }
-        } else {
-          final Set<DruidServerMetadata> segmentServers = segmentMetadata.getReplicas();
-          final ImmutableSet<DruidServerMetadata> servers = new ImmutableSet.Builder<DruidServerMetadata>()
-              .addAll(segmentServers)
-              .add(server)
-              .build();
-          final AvailableSegmentMetadata metadataWithNumReplicas = AvailableSegmentMetadata
-              .from(segmentMetadata)
-              .withReplicas(servers)
-              .withRealtime(recomputeIsRealtime(servers))
-              .build();
-          knownSegments.put(segment.getId(), metadataWithNumReplicas);
-          if (server.isSegmentReplicationTarget()) {
-            // If a segment shows up on a replicatable (historical) server at any point, then it must be immutable,
-            // even if it's also available on non-replicatable (realtime) servers.
-            mutableSegments.remove(segment.getId());
-            log.debug("Segment[%s] has become immutable.", segment.getId());
-          }
-        }
+        segmentMetadataInfo.compute(
+            segment.getDataSource(),
+            (datasource, segmentsMap) -> {
+              if (segmentsMap == null) {
+                segmentsMap = new ConcurrentSkipListMap<>(SEGMENT_ORDER);
+              }
+              segmentsMap.compute(
+                  segment.getId(),
+                  (segmentId, segmentMetadata) -> {
+                    if (segmentMetadata == null) {
+                      // Unknown segment.
+                      totalSegments++;
+                      // segmentReplicatable is used to determine if segments are served by historical or realtime servers
+                      long isRealtime = server.isSegmentReplicationTarget() ? 0 : 1;
+                      segmentMetadata = AvailableSegmentMetadata
+                          .builder(segment, isRealtime, ImmutableSet.of(server), null, DEFAULT_NUM_ROWS)
+                          .build();
+                      markSegmentAsNeedRefresh(segment.getId());
+                      if (!server.isSegmentReplicationTarget()) {
+                        log.debug("Added new mutable segment[%s].", segment.getId());
+                        markSegmentAsMutable(segment.getId());
+                      } else {
+                        log.debug("Added new immutable segment[%s].", segment.getId());
+                      }
+                    } else {
+                      // We know this segment.
+                      final Set<DruidServerMetadata> segmentServers = segmentMetadata.getReplicas();
+                      final ImmutableSet<DruidServerMetadata> servers = new ImmutableSet.Builder<DruidServerMetadata>()
+                          .addAll(segmentServers)
+                          .add(server)
+                          .build();
+                      segmentMetadata = AvailableSegmentMetadata
+                          .from(segmentMetadata)
+                          .withReplicas(servers)
+                          .withRealtime(recomputeIsRealtime(servers))
+                          .build();
+                      if (server.isSegmentReplicationTarget()) {
+                        // If a segment shows up on a replicatable (historical) server at any point, then it must be immutable,
+                        // even if it's also available on non-replicatable (realtime) servers.
+                        unmarkSegmentAsMutable(segment.getId());
+                        log.debug("Segment[%s] has become immutable.", segment.getId());
+                      }
+                    }
+                    assert segmentMetadata != null;
+                    return segmentMetadata;
+                  }
+              );
+
+              return segmentsMap;
+            }
+        );
       }
       if (!tables.containsKey(segment.getDataSource())) {
         refreshImmediately = true;
@@ -408,25 +484,36 @@ public class DruidSchema extends AbstractSchema
   @VisibleForTesting
   void removeSegment(final DataSegment segment)
   {
+    // Get the lock first so that we won't wait inside ConcurrentMap.compute().
     synchronized (lock) {
       log.debug("Segment[%s] is gone.", segment.getId());
 
       segmentsNeedingRefresh.remove(segment.getId());
-      mutableSegments.remove(segment.getId());
-
-      final Map<SegmentId, AvailableSegmentMetadata> dataSourceSegments =
-          segmentMetadataInfo.get(segment.getDataSource());
-      if (dataSourceSegments.remove(segment.getId()) != null) {
-        totalSegments--;
-      }
-
-      if (dataSourceSegments.isEmpty()) {
-        segmentMetadataInfo.remove(segment.getDataSource());
-        tables.remove(segment.getDataSource());
-        log.info("dataSource[%s] no longer exists, all metadata removed.", segment.getDataSource());
-      } else {
-        dataSourcesNeedingRebuild.add(segment.getDataSource());
-      }
+      unmarkSegmentAsMutable(segment.getId());
+
+      segmentMetadataInfo.compute(
+          segment.getDataSource(),
+          (dataSource, segmentsMap) -> {
+            if (segmentsMap == null) {
+              log.warn("Unknown segment[%s] was removed from the cluster. Ignoring this event.", segment.getId());
+              return null;
+            } else {
+              if (segmentsMap.remove(segment.getId()) == null) {
+                log.warn("Unknown segment[%s] was removed from the cluster. Ignoring this event.", segment.getId());
+              } else {
+                totalSegments--;
+              }
+              if (segmentsMap.isEmpty()) {
+                tables.remove(segment.getDataSource());
+                log.info("dataSource[%s] no longer exists, all metadata removed.", segment.getDataSource());
+                return null;
+              } else {
+                markDataSourceAsNeedRebuild(segment.getDataSource());
+                return segmentsMap;
+              }
+            }
+          }
+      );
 
       lock.notifyAll();
     }
@@ -435,38 +522,95 @@ public class DruidSchema extends AbstractSchema
   @VisibleForTesting
   void removeServerSegment(final DruidServerMetadata server, final DataSegment segment)
   {
+    // Get the lock first so that we won't wait inside ConcurrentMap.compute().
     synchronized (lock) {
       log.debug("Segment[%s] is gone from server[%s]", segment.getId(), server.getName());
-      final Map<SegmentId, AvailableSegmentMetadata> knownSegments = segmentMetadataInfo.get(segment.getDataSource());
+      segmentMetadataInfo.compute(
+          segment.getDataSource(),
+          (datasource, knownSegments) -> {
+            if (knownSegments == null) {
+              log.warn(
+                  "Unknown segment[%s] is removed from server[%s]. Ignoring this event",
+                  segment.getId(),
+                  server.getHost()
+              );
+              return null;
+            }
+
+            if (server.getType().equals(ServerType.BROKER)) {
+              // for brokers, if the segment drops from all historicals before the broker this could be null.
+              if (!knownSegments.isEmpty()) {
+                // a segment on a broker means a broadcast datasource, skip metadata because we'll also see this segment on the
+                // historical, however mark the datasource for refresh because it might no longer be broadcast or something
+                markDataSourceAsNeedRebuild(segment.getDataSource());
+              }
+            } else {
+              knownSegments.compute(
+                  segment.getId(),
+                  (segmentId, segmentMetadata) -> {
+                    if (segmentMetadata == null) {
+                      log.warn(
+                          "Unknown segment[%s] is removed from server[%s]. Ignoring this event",
+                          segment.getId(),
+                          server.getHost()
+                      );
+                      return null;
+                    } else {
+                      final Set<DruidServerMetadata> segmentServers = segmentMetadata.getReplicas();
+                      final ImmutableSet<DruidServerMetadata> servers = FluentIterable
+                          .from(segmentServers)
+                          .filter(Predicates.not(Predicates.equalTo(server)))
+                          .toSet();
+                      return AvailableSegmentMetadata
+                          .from(segmentMetadata)
+                          .withReplicas(servers)
+                          .withRealtime(recomputeIsRealtime(servers))
+                          .build();
+                    }
+                  }
+              );
+            }
+            if (knownSegments.isEmpty()) {
+              return null;
+            } else {
+              return knownSegments;
+            }
+          }
+      );
 
-      // someday we could hypothetically remove broker special casing, whenever BrokerServerView supports tracking
-      // broker served segments in the timeline, to ensure that removeSegment the event is triggered accurately
-      if (server.getType().equals(ServerType.BROKER)) {
-        // for brokers, if the segment drops from all historicals before the broker this could be null.
-        if (knownSegments != null && !knownSegments.isEmpty()) {
-          // a segment on a broker means a broadcast datasource, skip metadata because we'll also see this segment on the
-          // historical, however mark the datasource for refresh because it might no longer be broadcast or something
-          dataSourcesNeedingRebuild.add(segment.getDataSource());
-        }
-      } else {
-        final AvailableSegmentMetadata segmentMetadata = knownSegments.get(segment.getId());
-        final Set<DruidServerMetadata> segmentServers = segmentMetadata.getReplicas();
-        final ImmutableSet<DruidServerMetadata> servers = FluentIterable
-            .from(segmentServers)
-            .filter(Predicates.not(Predicates.equalTo(server)))
-            .toSet();
-
-        final AvailableSegmentMetadata metadataWithNumReplicas = AvailableSegmentMetadata
-            .from(segmentMetadata)
-            .withReplicas(servers)
-            .withRealtime(recomputeIsRealtime(servers))
-            .build();
-        knownSegments.put(segment.getId(), metadataWithNumReplicas);
-      }
       lock.notifyAll();
     }
   }
 
+  private void markSegmentAsNeedRefresh(SegmentId segmentId)
+  {
+    synchronized (lock) {
+      segmentsNeedingRefresh.add(segmentId);
+    }
+  }
+
+  private void markSegmentAsMutable(SegmentId segmentId)
+  {
+    synchronized (lock) {
+      mutableSegments.add(segmentId);
+    }
+  }
+
+  private void unmarkSegmentAsMutable(SegmentId segmentId)
+  {
+    synchronized (lock) {
+      mutableSegments.remove(segmentId);
+    }
+  }
+
+  @VisibleForTesting
+  void markDataSourceAsNeedRebuild(String datasource)
+  {
+    synchronized (lock) {
+      dataSourcesNeedingRebuild.add(datasource);
+    }
+  }
+
   /**
    * Attempt to refresh "segmentSignatures" for a set of segments. Returns the set of segments actually refreshed,
    * which may be a subset of the asked-for set.
@@ -494,14 +638,19 @@ public class DruidSchema extends AbstractSchema
 
   private long recomputeIsRealtime(ImmutableSet<DruidServerMetadata> servers)
   {
+    if (servers.isEmpty()) {
+      return 0;
+    }
     final Optional<DruidServerMetadata> historicalServer = servers
         .stream()
+        // Ideally, this filter should have checked whether it's a broadcast segment loaded in brokers.
+        // However, we don't currently keep track of the broadcast segments loaded in brokers, so this filter is still valid.
+        // See addSegment(), removeServerSegment(), and removeSegment()
         .filter(metadata -> metadata.getType().equals(ServerType.HISTORICAL))
         .findAny();
 
     // if there is any historical server in the replicas, isRealtime flag should be unset
-    final long isRealtime = historicalServer.isPresent() ? 0 : 1;
-    return isRealtime;
+    return historicalServer.isPresent() ? 0 : 1;
   }
 
   /**
@@ -540,33 +689,46 @@ public class DruidSchema extends AbstractSchema
         if (segmentId == null) {
           log.warn("Got analysis for segment[%s] we didn't ask for, ignoring.", analysis.getId());
         } else {
-          synchronized (lock) {
-            final RowSignature rowSignature = analysisToRowSignature(analysis);
-            log.debug("Segment[%s] has signature[%s].", segmentId, rowSignature);
-            final Map<SegmentId, AvailableSegmentMetadata> dataSourceSegments = segmentMetadataInfo.get(dataSource);
-            if (dataSourceSegments == null) {
-              // Datasource may have been removed or become unavailable while this refresh was ongoing.
-              log.warn(
-                  "No segment map found with datasource[%s], skipping refresh of segment[%s]",
-                  dataSource,
-                  segmentId
-              );
-            } else {
-              final AvailableSegmentMetadata segmentMetadata = dataSourceSegments.get(segmentId);
-              if (segmentMetadata == null) {
-                log.warn("No segment[%s] found, skipping refresh", segmentId);
-              } else {
-                final AvailableSegmentMetadata updatedSegmentMetadata = AvailableSegmentMetadata
-                    .from(segmentMetadata)
-                    .withRowSignature(rowSignature)
-                    .withNumRows(analysis.getNumRows())
-                    .build();
-                dataSourceSegments.put(segmentId, updatedSegmentMetadata);
-                setAvailableSegmentMetadata(segmentId, updatedSegmentMetadata);
-                retVal.add(segmentId);
+          final RowSignature rowSignature = analysisToRowSignature(analysis);
+          log.debug("Segment[%s] has signature[%s].", segmentId, rowSignature);
+          segmentMetadataInfo.compute(
+              dataSource,
+              (datasourceKey, dataSourceSegments) -> {
+                if (dataSourceSegments == null) {
+                  // Datasource may have been removed or become unavailable while this refresh was ongoing.
+                  log.warn(
+                      "No segment map found with datasource[%s], skipping refresh of segment[%s]",
+                      datasourceKey,
+                      segmentId
+                  );
+                  return null;
+                } else {
+                  dataSourceSegments.compute(
+                      segmentId,
+                      (segmentIdKey, segmentMetadata) -> {
+                        if (segmentMetadata == null) {
+                          log.warn("No segment[%s] found, skipping refresh", segmentId);
+                          return null;
+                        } else {
+                          final AvailableSegmentMetadata updatedSegmentMetadata = AvailableSegmentMetadata
+                              .from(segmentMetadata)
+                              .withRowSignature(rowSignature)
+                              .withNumRows(analysis.getNumRows())
+                              .build();
+                          retVal.add(segmentId);
+                          return updatedSegmentMetadata;
+                        }
+                      }
+                  );
+
+                  if (dataSourceSegments.isEmpty()) {
+                    return null;
+                  } else {
+                    return dataSourceSegments;
+                  }
+                }
               }
-            }
-          }
+          );
         }
 
         yielder = yielder.next(null);
@@ -588,60 +750,88 @@ public class DruidSchema extends AbstractSchema
   }
 
   @VisibleForTesting
-  void setAvailableSegmentMetadata(final SegmentId segmentId, final AvailableSegmentMetadata availableSegmentMetadata)
+  DruidTable buildDruidTable(final String dataSource)
   {
-    synchronized (lock) {
-      TreeMap<SegmentId, AvailableSegmentMetadata> dataSourceSegments = segmentMetadataInfo.computeIfAbsent(
-          segmentId.getDataSource(),
-          x -> new TreeMap<>(SEGMENT_ORDER)
-      );
-      if (dataSourceSegments.put(segmentId, availableSegmentMetadata) == null) {
-        totalSegments++;
+    ConcurrentSkipListMap<SegmentId, AvailableSegmentMetadata> segmentsMap = segmentMetadataInfo.get(dataSource);
+    final Map<String, ValueType> columnTypes = new TreeMap<>();
+
+    if (segmentsMap != null) {
+      for (AvailableSegmentMetadata availableSegmentMetadata : segmentsMap.values()) {
+        final RowSignature rowSignature = availableSegmentMetadata.getRowSignature();
+        if (rowSignature != null) {
+          for (String column : rowSignature.getColumnNames()) {
+            // Newer column types should override older ones.
+            final ValueType columnType =
+                rowSignature.getColumnType(column)
+                            .orElseThrow(() -> new ISE("Encountered null type for column[%s]", column));
+
+            columnTypes.putIfAbsent(column, columnType);
+          }
+        }
       }
     }
+
+    final RowSignature.Builder builder = RowSignature.builder();
+    columnTypes.forEach(builder::add);
+
+    final TableDataSource tableDataSource;
+
+    // to be a GlobalTableDataSource instead of a TableDataSource, it must appear on all servers (inferred by existing
+    // in the segment cache, which in this case belongs to the broker meaning only broadcast segments live here)
+    // to be joinable, it must be possibly joinable according to the factory. we only consider broadcast datasources
+    // at this time, and isGlobal is currently strongly coupled with joinable, so only make a global table datasource
+    // if also joinable
+    final GlobalTableDataSource maybeGlobal = new GlobalTableDataSource(dataSource);
+    final boolean isJoinable = joinableFactory.isDirectlyJoinable(maybeGlobal);
+    final boolean isBroadcast = segmentManager.getDataSourceNames().contains(dataSource);
+    if (isBroadcast && isJoinable) {
+      tableDataSource = maybeGlobal;
+    } else {
+      tableDataSource = new TableDataSource(dataSource);
+    }
+    return new DruidTable(tableDataSource, builder.build(), isJoinable, isBroadcast);
   }
 
-  protected DruidTable buildDruidTable(final String dataSource)
+  @VisibleForTesting
+  Map<SegmentId, AvailableSegmentMetadata> getSegmentMetadataSnapshot()
+  {
+    final Map<SegmentId, AvailableSegmentMetadata> segmentMetadata = Maps.newHashMapWithExpectedSize(totalSegments);
+    for (ConcurrentSkipListMap<SegmentId, AvailableSegmentMetadata> val : segmentMetadataInfo.values()) {
+      segmentMetadata.putAll(val);
+    }
+    return segmentMetadata;
+  }
+
+  /**
+   * Returns the total number of segments. This method intentionally doesn't use the lock, to avoid expensive
+   * contention. As a result, the returned value might be inexact.
+   */
+  int getTotalSegments()
+  {
+    return totalSegments;
+  }
+
+  @VisibleForTesting
+  TreeSet<SegmentId> getSegmentsNeedingRefresh()
   {
     synchronized (lock) {
-      final Map<SegmentId, AvailableSegmentMetadata> segmentMap = segmentMetadataInfo.get(dataSource);
-      final Map<String, ValueType> columnTypes = new TreeMap<>();
-
-      if (segmentMap != null) {
-        for (AvailableSegmentMetadata availableSegmentMetadata : segmentMap.values()) {
-          final RowSignature rowSignature = availableSegmentMetadata.getRowSignature();
-          if (rowSignature != null) {
-            for (String column : rowSignature.getColumnNames()) {
-              // Newer column types should override older ones.
-              final ValueType columnType =
-                  rowSignature.getColumnType(column)
-                              .orElseThrow(() -> new ISE("Encountered null type for column[%s]", column));
-
-              columnTypes.putIfAbsent(column, columnType);
-            }
-          }
-        }
-      }
+      return segmentsNeedingRefresh;
+    }
+  }
 
-      final RowSignature.Builder builder = RowSignature.builder();
-      columnTypes.forEach(builder::add);
-
-      final TableDataSource tableDataSource;
-
-      // to be a GlobalTableDataSource instead of a TableDataSource, it must appear on all servers (inferred by existing
-      // in the segment cache, which in this case belongs to the broker meaning only broadcast segments live here)
-      // to be joinable, it must be possibly joinable according to the factory. we only consider broadcast datasources
-      // at this time, and isGlobal is currently strongly coupled with joinable, so only make a global table datasource
-      // if also joinable
-      final GlobalTableDataSource maybeGlobal = new GlobalTableDataSource(dataSource);
-      final boolean isJoinable = joinableFactory.isDirectlyJoinable(maybeGlobal);
-      final boolean isBroadcast = segmentManager.getDataSourceNames().contains(dataSource);
-      if (isBroadcast && isJoinable) {
-        tableDataSource = maybeGlobal;
-      } else {
-        tableDataSource = new TableDataSource(dataSource);
-      }
-      return new DruidTable(tableDataSource, builder.build(), isJoinable, isBroadcast);
+  @VisibleForTesting
+  TreeSet<SegmentId> getMutableSegments()
+  {
+    synchronized (lock) {
+      return mutableSegments;
+    }
+  }
+
+  @VisibleForTesting
+  Set<String> getDataSourcesNeedingRebuild()
+  {
+    synchronized (lock) {
+      return dataSourcesNeedingRebuild;
     }
   }
 
@@ -700,20 +890,31 @@ public class DruidSchema extends AbstractSchema
     return rowSignatureBuilder.build();
   }
 
-  Map<SegmentId, AvailableSegmentMetadata> getSegmentMetadataSnapshot()
+  /**
+   * This method is not thread-safe and must be used only in unit tests.
+   */
+  @VisibleForTesting
+  void setAvailableSegmentMetadata(final SegmentId segmentId, final AvailableSegmentMetadata availableSegmentMetadata)
   {
-    synchronized (lock) {
-      final Map<SegmentId, AvailableSegmentMetadata> segmentMetadata = Maps.newHashMapWithExpectedSize(
-          segmentMetadataInfo.values().stream().mapToInt(v -> v.size()).sum());
-      for (TreeMap<SegmentId, AvailableSegmentMetadata> val : segmentMetadataInfo.values()) {
-        segmentMetadata.putAll(val);
-      }
-      return segmentMetadata;
+    final ConcurrentSkipListMap<SegmentId, AvailableSegmentMetadata> dataSourceSegments = segmentMetadataInfo
+        .computeIfAbsent(
+            segmentId.getDataSource(),
+            k -> new ConcurrentSkipListMap<>(SEGMENT_ORDER)
+        );
+    if (dataSourceSegments.put(segmentId, availableSegmentMetadata) == null) {
+      totalSegments++;
     }
   }
 
-  int getTotalSegments()
+  /**
+   * This is a helper method for unit tests to emulate heavy work done with {@link #lock}.
+   * It must be used only in unit tests.
+   */
+  @VisibleForTesting
+  void doInLock(Runnable runnable)
   {
-    return totalSegments;
+    synchronized (lock) {
+      runnable.run();
+    }
   }
 }
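On the read side (getSegmentMetadataSnapshot() and buildDruidTable() above), the patch no longer takes the coarse lock; readers rely on the weakly consistent iteration of the concurrent maps. Below is a rough sketch of that snapshot pattern, again with placeholder names and value types rather than the real Druid classes.

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentSkipListMap;

public class SnapshotSketch
{
  // dataSource -> (segmentId -> numRows); same placeholder shape as the writer sketch above.
  private final ConcurrentHashMap<String, ConcurrentSkipListMap<String, Long>> segmentsByDataSource =
      new ConcurrentHashMap<>();

  public Map<String, Long> snapshot()
  {
    // No coarse lock here: iteration is weakly consistent, so a concurrent writer may be missed
    // or partially observed, which is acceptable for the system tables this feeds.
    final Map<String, Long> snapshot = new HashMap<>();
    for (ConcurrentSkipListMap<String, Long> segments : segmentsByDataSource.values()) {
      snapshot.putAll(segments);
    }
    return snapshot;
  }
}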
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/schema/DruidSchemaConcurrencyTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/schema/DruidSchemaConcurrencyTest.java
new file mode 100644
index 0000000..129bf9d
--- /dev/null
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/schema/DruidSchemaConcurrencyTest.java
@@ -0,0 +1,485 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.sql.calcite.schema;
+
+import com.google.common.base.Predicate;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Sets;
+import org.apache.druid.client.BrokerSegmentWatcherConfig;
+import org.apache.druid.client.BrokerServerView;
+import org.apache.druid.client.DruidServer;
+import org.apache.druid.client.FilteredServerInventoryView;
+import org.apache.druid.client.ServerView.CallbackAction;
+import org.apache.druid.client.ServerView.SegmentCallback;
+import org.apache.druid.client.ServerView.ServerRemovedCallback;
+import org.apache.druid.client.SingleServerInventoryView.FilteringSegmentCallback;
+import org.apache.druid.client.TimelineServerView.TimelineCallback;
+import org.apache.druid.client.selector.HighestPriorityTierSelectorStrategy;
+import org.apache.druid.client.selector.RandomServerSelectorStrategy;
+import org.apache.druid.jackson.DefaultObjectMapper;
+import org.apache.druid.java.util.common.Intervals;
+import org.apache.druid.java.util.common.NonnullPair;
+import org.apache.druid.java.util.common.Pair;
+import org.apache.druid.java.util.common.concurrent.Execs;
+import org.apache.druid.java.util.http.client.HttpClient;
+import org.apache.druid.query.QueryToolChestWarehouse;
+import org.apache.druid.query.QueryWatcher;
+import org.apache.druid.query.TableDataSource;
+import org.apache.druid.query.aggregation.CountAggregatorFactory;
+import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
+import org.apache.druid.query.planning.DataSourceAnalysis;
+import org.apache.druid.segment.IndexBuilder;
+import org.apache.druid.segment.QueryableIndex;
+import org.apache.druid.segment.incremental.IncrementalIndexSchema;
+import org.apache.druid.segment.join.MapJoinableFactory;
+import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory;
+import org.apache.druid.server.coordination.DruidServerMetadata;
+import org.apache.druid.server.coordination.ServerType;
+import org.apache.druid.server.metrics.NoopServiceEmitter;
+import org.apache.druid.server.security.NoopEscalator;
+import org.apache.druid.sql.calcite.table.DruidTable;
+import org.apache.druid.sql.calcite.util.CalciteTests;
+import org.apache.druid.sql.calcite.util.SpecificSegmentsQuerySegmentWalker;
+import org.apache.druid.timeline.DataSegment;
+import org.apache.druid.timeline.DataSegment.PruneSpecsHolder;
+import org.apache.druid.timeline.SegmentId;
+import org.apache.druid.timeline.partition.NumberedShardSpec;
+import org.easymock.EasyMock;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import javax.annotation.Nullable;
+import java.io.File;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Executor;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import java.util.stream.Collectors;
+
+public class DruidSchemaConcurrencyTest extends DruidSchemaTestCommon
+{
+  private static final String DATASOURCE = "datasource";
+
+  private File tmpDir;
+  private SpecificSegmentsQuerySegmentWalker walker;
+  private TestServerInventoryView inventoryView;
+  private BrokerServerView serverView;
+  private DruidSchema schema;
+  private ExecutorService exec;
+
+  @Before
+  public void setUp() throws Exception
+  {
+    tmpDir = temporaryFolder.newFolder();
+    walker = new SpecificSegmentsQuerySegmentWalker(conglomerate);
+    inventoryView = new TestServerInventoryView();
+    serverView = newBrokerServerView(inventoryView);
+    inventoryView.init();
+    serverView.awaitInitialization();
+    exec = Execs.multiThreaded(4, "DruidSchemaConcurrencyTest-%d");
+  }
+
+  @After
+  public void tearDown() throws Exception
+  {
+    exec.shutdownNow();
+    walker.close();
+  }
+
+  /**
+   * This tests the contention among three components: DruidSchema, InventoryView, and BrokerServerView.
+   * It first triggers refreshing DruidSchema. To mimic some heavy work done with {@link DruidSchema#lock},
+   * {@link DruidSchema#buildDruidTable} is overridden to sleep before doing real work. While refreshing DruidSchema,
+   * more new segments are added to InventoryView, which triggers updates of BrokerServerView. Finally, while
+   * BrokerServerView is updated, {@link BrokerServerView#getTimeline} is continuously called to mimic user query
+   * processing. All these calls must return without heavy contention.
+   */
+  @Test(timeout = 30000L)
+  public void testDruidSchemaRefreshAndInventoryViewAddSegmentAndBrokerServerViewGetTimeline()
+      throws InterruptedException, ExecutionException, TimeoutException
+  {
+    schema = new DruidSchema(
+        CalciteTests.createMockQueryLifecycleFactory(walker, conglomerate),
+        serverView,
+        segmentManager,
+        new MapJoinableFactory(ImmutableSet.of(), ImmutableMap.of()),
+        PLANNER_CONFIG_DEFAULT,
+        new NoopEscalator()
+    )
+    {
+      @Override
+      DruidTable buildDruidTable(final String dataSource)
+      {
+        doInLock(() -> {
+          try {
+            // Mimic some heavy work done in lock in DruidSchema
+            Thread.sleep(5000);
+          }
+          catch (InterruptedException e) {
+            throw new RuntimeException(e);
+          }
+        });
+        return super.buildDruidTable(dataSource);
+      }
+    };
+
+    int numExistingSegments = 100;
+    int numServers = 19;
+    CountDownLatch segmentLoadLatch = new CountDownLatch(numExistingSegments);
+    serverView.registerTimelineCallback(
+        Execs.directExecutor(),
+        new TimelineCallback()
+        {
+          @Override
+          public CallbackAction timelineInitialized()
+          {
+            return CallbackAction.CONTINUE;
+          }
+
+          @Override
+          public CallbackAction segmentAdded(DruidServerMetadata server, DataSegment segment)
+          {
+            segmentLoadLatch.countDown();
+            return CallbackAction.CONTINUE;
+          }
+
+          @Override
+          public CallbackAction segmentRemoved(DataSegment segment)
+          {
+            return CallbackAction.CONTINUE;
+          }
+
+          @Override
+          public CallbackAction serverSegmentRemoved(DruidServerMetadata server, DataSegment segment)
+          {
+            return CallbackAction.CONTINUE;
+          }
+        }
+    );
+    addSegmentsToCluster(0, numServers, numExistingSegments);
+    // Wait for all segments to be loaded in BrokerServerView
+    Assert.assertTrue(segmentLoadLatch.await(5, TimeUnit.SECONDS));
+
+    // Trigger refresh of DruidSchema. This will internally run the heavy work mimicked by the overridden buildDruidTable
+    Future refreshFuture = exec.submit(() -> {
+      schema.refresh(
+          walker.getSegments().stream().map(DataSegment::getId).collect(Collectors.toSet()),
+          Sets.newHashSet(DATASOURCE)
+      );
+      return null;
+    });
+
+    // Trigger updates of BrokerServerView. This should be done asynchronously.
+    addSegmentsToCluster(numExistingSegments, numServers, 50); // add completely new segments
+    addReplicasToCluster(1, numServers, 30); // add replicas of the first 30 segments.
+    // for the first 30 segments, we will still have replicas.
+    // for the other 20 segments, they will be completely removed from the cluster.
+    removeSegmentsFromCluster(numServers, 50);
+    Assert.assertFalse(refreshFuture.isDone());
+
+    for (int i = 0; i < 1000; i++) {
+      boolean hasTimeline = exec.submit(
+          () -> serverView.getTimeline(DataSourceAnalysis.forDataSource(new TableDataSource(DATASOURCE)))
+                          .isPresent()
+      ).get(100, TimeUnit.MILLISECONDS);
+      Assert.assertTrue(hasTimeline);
+      // We want to call getTimeline while BrokerServerView is being updated. Sleep might help with timing.
+      Thread.sleep(2);
+    }
+
+    refreshFuture.get(10, TimeUnit.SECONDS);
+  }
+
+  /**
+   * This tests the contention between two methods of DruidSchema, {@link DruidSchema#refresh} and
+   * {@link DruidSchema#getSegmentMetadataSnapshot()}. It first triggers refreshing DruidSchema.
+   * To mimic some heavy work done with {@link DruidSchema#lock}, {@link DruidSchema#buildDruidTable} is overridden
+   * to sleep before doing real work. While refreshing DruidSchema, getSegmentMetadataSnapshot() is continuously
+   * called to mimic reading the segments table of SystemSchema. All these calls must return without heavy contention.
+   */
+  @Test(timeout = 30000L)
+  public void testDruidSchemaRefreshAndDruidSchemaGetSegmentMetadata()
+      throws InterruptedException, ExecutionException, TimeoutException
+  {
+    schema = new DruidSchema(
+        CalciteTests.createMockQueryLifecycleFactory(walker, conglomerate),
+        serverView,
+        segmentManager,
+        new MapJoinableFactory(ImmutableSet.of(), ImmutableMap.of()),
+        PLANNER_CONFIG_DEFAULT,
+        new NoopEscalator()
+    )
+    {
+      @Override
+      DruidTable buildDruidTable(final String dataSource)
+      {
+        doInLock(() -> {
+          try {
+            // Mimic some heavy work done in lock in DruidSchema
+            Thread.sleep(5000);
+          }
+          catch (InterruptedException e) {
+            throw new RuntimeException(e);
+          }
+        });
+        return super.buildDruidTable(dataSource);
+      }
+    };
+
+    int numExistingSegments = 100;
+    int numServers = 19;
+    CountDownLatch segmentLoadLatch = new CountDownLatch(numExistingSegments);
+    serverView.registerTimelineCallback(
+        Execs.directExecutor(),
+        new TimelineCallback()
+        {
+          @Override
+          public CallbackAction timelineInitialized()
+          {
+            return CallbackAction.CONTINUE;
+          }
+
+          @Override
+          public CallbackAction segmentAdded(DruidServerMetadata server, DataSegment segment)
+          {
+            segmentLoadLatch.countDown();
+            return CallbackAction.CONTINUE;
+          }
+
+          @Override
+          public CallbackAction segmentRemoved(DataSegment segment)
+          {
+            return CallbackAction.CONTINUE;
+          }
+
+          @Override
+          public CallbackAction serverSegmentRemoved(DruidServerMetadata server, DataSegment segment)
+          {
+            return CallbackAction.CONTINUE;
+          }
+        }
+    );
+    addSegmentsToCluster(0, numServers, numExistingSegments);
+    // Wait for all segments to be loaded in BrokerServerView
+    Assert.assertTrue(segmentLoadLatch.await(5, TimeUnit.SECONDS));
+
+    // Trigger refresh of DruidSchema. This will internally run the heavy work mimicked by the overridden buildDruidTable
+    Future refreshFuture = exec.submit(() -> {
+      schema.refresh(
+          walker.getSegments().stream().map(DataSegment::getId).collect(Collectors.toSet()),
+          Sets.newHashSet(DATASOURCE)
+      );
+      return null;
+    });
+    Assert.assertFalse(refreshFuture.isDone());
+
+    for (int i = 0; i < 1000; i++) {
+      Map<SegmentId, AvailableSegmentMetadata> segmentsMetadata = exec.submit(
+          () -> schema.getSegmentMetadataSnapshot()
+      ).get(100, TimeUnit.MILLISECONDS);
+      Assert.assertFalse(segmentsMetadata.isEmpty());
+      // We want to call getSegmentMetadataSnapshot while refreshing. Sleep might help with timing.
+      Thread.sleep(2);
+    }
+
+    refreshFuture.get(10, TimeUnit.SECONDS);
+  }
+
+  private void addSegmentsToCluster(int partitionIdStart, int numServers, int numSegments)
+  {
+    for (int i = 0; i < numSegments; i++) {
+      DataSegment segment = newSegment(i + partitionIdStart);
+      QueryableIndex index = newQueryableIndex(i + partitionIdStart);
+      walker.add(segment, index);
+      int serverIndex = i % numServers;
+      inventoryView.addServerSegment(newServer("server_" + serverIndex), segment);
+    }
+  }
+
+  private void addReplicasToCluster(int serverIndexOffFrom, int numServers, int numSegments)
+  {
+    for (int i = 0; i < numSegments; i++) {
+      DataSegment segment = newSegment(i);
+      int serverIndex = i % numServers + serverIndexOffFrom;
+      serverIndex = serverIndex < numServers ? serverIndex : serverIndex - numServers;
+      inventoryView.addServerSegment(newServer("server_" + serverIndex), segment);
+    }
+  }
+
+  private void removeSegmentsFromCluster(int numServers, int numSegments)
+  {
+    for (int i = 0; i < numSegments; i++) {
+      DataSegment segment = newSegment(i);
+      int serverIndex = i % numServers;
+      inventoryView.removeServerSegment(newServer("server_" + serverIndex), segment);
+    }
+  }
+
+  private static BrokerServerView newBrokerServerView(FilteredServerInventoryView baseView)
+  {
+    return new BrokerServerView(
+        EasyMock.createMock(QueryToolChestWarehouse.class),
+        EasyMock.createMock(QueryWatcher.class),
+        new DefaultObjectMapper(),
+        EasyMock.createMock(HttpClient.class),
+        baseView,
+        new HighestPriorityTierSelectorStrategy(new RandomServerSelectorStrategy()),
+        new NoopServiceEmitter(),
+        new BrokerSegmentWatcherConfig()
+    );
+  }
+
+  private static DruidServer newServer(String name)
+  {
+    return new DruidServer(
+        name,
+        "host:8083",
+        "host:8283",
+        1000L,
+        ServerType.HISTORICAL,
+        "tier",
+        0
+    );
+  }
+
+  private DataSegment newSegment(int partitionId)
+  {
+    return new DataSegment(
+        DATASOURCE,
+        Intervals.of("2012/2013"),
+        "version1",
+        null,
+        ImmutableList.of(),
+        ImmutableList.of(),
+        new NumberedShardSpec(partitionId, 0),
+        null,
+        1,
+        100L,
+        PruneSpecsHolder.DEFAULT
+    );
+  }
+
+  private QueryableIndex newQueryableIndex(int partitionId)
+  {
+    return IndexBuilder.create()
+                       .tmpDir(new File(tmpDir, "" + partitionId))
+                       .segmentWriteOutMediumFactory(OffHeapMemorySegmentWriteOutMediumFactory.instance())
+                       .schema(
+                           new IncrementalIndexSchema.Builder()
+                               .withMetrics(
+                                   new CountAggregatorFactory("cnt"),
+                                   new DoubleSumAggregatorFactory("m1", "m1")
+                               )
+                               .withRollup(false)
+                               .build()
+                       )
+                       .rows(ROWS1)
+                       .buildMMappedIndex();
+  }
+
+  private static class TestServerInventoryView implements FilteredServerInventoryView
+  {
+    private final Map<String, DruidServer> serverMap = new HashMap<>();
+    private final Map<String, Set<DataSegment>> segmentsMap = new HashMap<>();
+    private final List<NonnullPair<SegmentCallback, Executor>> segmentCallbacks = new ArrayList<>();
+    private final List<NonnullPair<ServerRemovedCallback, Executor>> serverRemovedCallbacks = new ArrayList<>();
+
+    private void init()
+    {
+      segmentCallbacks.forEach(pair -> pair.rhs.execute(pair.lhs::segmentViewInitialized));
+    }
+
+    private void addServerSegment(DruidServer server, DataSegment segment)
+    {
+      serverMap.put(server.getName(), server);
+      segmentsMap.computeIfAbsent(server.getName(), k -> new HashSet<>()).add(segment);
+      segmentCallbacks.forEach(pair -> pair.rhs.execute(() -> pair.lhs.segmentAdded(server.getMetadata(), segment)));
+    }
+
+    private void removeServerSegment(DruidServer server, DataSegment segment)
+    {
+      segmentsMap.computeIfAbsent(server.getName(), k -> new HashSet<>()).remove(segment);
+      segmentCallbacks.forEach(pair -> pair.rhs.execute(() -> pair.lhs.segmentRemoved(server.getMetadata(), segment)));
+    }
+
+    private void removeServer(DruidServer server)
+    {
+      serverMap.remove(server.getName());
+      segmentsMap.remove(server.getName());
+      serverRemovedCallbacks.forEach(pair -> pair.rhs.execute(() -> pair.lhs.serverRemoved(server)));
+    }
+
+    @Override
+    public void registerSegmentCallback(
+        Executor exec,
+        SegmentCallback callback,
+        Predicate<Pair<DruidServerMetadata, DataSegment>> filter
+    )
+    {
+      segmentCallbacks.add(new NonnullPair<>(new FilteringSegmentCallback(callback, filter), exec));
+    }
+
+    @Override
+    public void registerServerRemovedCallback(Executor exec, ServerRemovedCallback callback)
+    {
+      serverRemovedCallbacks.add(new NonnullPair<>(callback, exec));
+    }
+
+    @Nullable
+    @Override
+    public DruidServer getInventoryValue(String serverKey)
+    {
+      return serverMap.get(serverKey);
+    }
+
+    @Override
+    public Collection<DruidServer> getInventory()
+    {
+      return serverMap.values();
+    }
+
+    @Override
+    public boolean isStarted()
+    {
+      return true;
+    }
+
+    @Override
+    public boolean isSegmentLoadedByServer(String serverKey, DataSegment segment)
+    {
+      Set<DataSegment> segments = segmentsMap.get(serverKey);
+      return segments != null && segments.contains(segment);
+    }
+  }
+}
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/schema/DruidSchemaTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/schema/DruidSchemaTest.java
index f66bfc0..e2a7554 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/schema/DruidSchemaTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/schema/DruidSchemaTest.java
@@ -29,13 +29,9 @@ import org.apache.calcite.rel.type.RelDataTypeField;
 import org.apache.calcite.schema.Table;
 import org.apache.calcite.sql.type.SqlTypeName;
 import org.apache.druid.client.ImmutableDruidServer;
-import org.apache.druid.data.input.InputRow;
 import org.apache.druid.java.util.common.Intervals;
 import org.apache.druid.java.util.common.Pair;
-import org.apache.druid.java.util.common.io.Closer;
-import org.apache.druid.query.DataSource;
 import org.apache.druid.query.GlobalTableDataSource;
-import org.apache.druid.query.QueryRunnerFactoryConglomerate;
 import org.apache.druid.query.TableDataSource;
 import org.apache.druid.query.aggregation.CountAggregatorFactory;
 import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
@@ -44,20 +40,12 @@ import org.apache.druid.query.aggregation.hyperloglog.HyperUniquesAggregatorFact
 import org.apache.druid.segment.IndexBuilder;
 import org.apache.druid.segment.QueryableIndex;
 import org.apache.druid.segment.incremental.IncrementalIndexSchema;
-import org.apache.druid.segment.join.JoinConditionAnalysis;
-import org.apache.druid.segment.join.Joinable;
-import org.apache.druid.segment.join.JoinableFactory;
 import org.apache.druid.segment.join.MapJoinableFactory;
-import org.apache.druid.segment.loading.SegmentLoader;
 import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory;
-import org.apache.druid.server.QueryStackTests;
-import org.apache.druid.server.SegmentManager;
 import org.apache.druid.server.coordination.DruidServerMetadata;
 import org.apache.druid.server.coordination.ServerType;
 import org.apache.druid.server.security.NoopEscalator;
-import org.apache.druid.sql.calcite.planner.PlannerConfig;
 import org.apache.druid.sql.calcite.table.DruidTable;
-import org.apache.druid.sql.calcite.util.CalciteTestBase;
 import org.apache.druid.sql.calcite.util.CalciteTests;
 import org.apache.druid.sql.calcite.util.SpecificSegmentsQuerySegmentWalker;
 import org.apache.druid.sql.calcite.util.TestServerInventoryView;
@@ -66,87 +54,33 @@ import org.apache.druid.timeline.DataSegment.PruneSpecsHolder;
 import org.apache.druid.timeline.SegmentId;
 import org.apache.druid.timeline.partition.LinearShardSpec;
 import org.apache.druid.timeline.partition.NumberedShardSpec;
-import org.easymock.EasyMock;
-import org.joda.time.Period;
 import org.junit.After;
-import org.junit.AfterClass;
 import org.junit.Assert;
 import org.junit.Before;
-import org.junit.BeforeClass;
-import org.junit.Rule;
 import org.junit.Test;
-import org.junit.rules.TemporaryFolder;
 
 import java.io.File;
 import java.io.IOException;
 import java.util.List;
 import java.util.Map;
-import java.util.Optional;
 import java.util.Set;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 import java.util.stream.Collectors;
 
-public class DruidSchemaTest extends CalciteTestBase
+public class DruidSchemaTest extends DruidSchemaTestCommon
 {
-  private static final PlannerConfig PLANNER_CONFIG_DEFAULT = new PlannerConfig()
-  {
-    @Override
-    public Period getMetadataRefreshPeriod()
-    {
-      return new Period("PT1S");
-    }
-  };
-
-  private static final List<InputRow> ROWS1 = ImmutableList.of(
-      CalciteTests.createRow(ImmutableMap.of("t", "2000-01-01", "m1", "1.0", "dim1", "")),
-      CalciteTests.createRow(ImmutableMap.of("t", "2000-01-02", "m1", "2.0", "dim1", "10.1")),
-      CalciteTests.createRow(ImmutableMap.of("t", "2000-01-03", "m1", "3.0", "dim1", "2"))
-  );
-
-  private static final List<InputRow> ROWS2 = ImmutableList.of(
-      CalciteTests.createRow(ImmutableMap.of("t", "2001-01-01", "m1", "4.0", "dim2", ImmutableList.of("a"))),
-      CalciteTests.createRow(ImmutableMap.of("t", "2001-01-02", "m1", "5.0", "dim2", ImmutableList.of("abc"))),
-      CalciteTests.createRow(ImmutableMap.of("t", "2001-01-03", "m1", "6.0"))
-  );
-
-  private static QueryRunnerFactoryConglomerate conglomerate;
-  private static Closer resourceCloser;
-
+  private SpecificSegmentsQuerySegmentWalker walker = null;
   private TestServerInventoryView serverView;
   private List<ImmutableDruidServer> druidServers;
-  private CountDownLatch getDatasourcesLatch = new CountDownLatch(1);
-  private CountDownLatch buildTableLatch = new CountDownLatch(1);
-
-  @BeforeClass
-  public static void setUpClass()
-  {
-    resourceCloser = Closer.create();
-    conglomerate = QueryStackTests.createQueryRunnerFactoryConglomerate(resourceCloser);
-  }
-
-  @AfterClass
-  public static void tearDownClass() throws IOException
-  {
-    resourceCloser.close();
-  }
-
-  @Rule
-  public TemporaryFolder temporaryFolder = new TemporaryFolder();
-
-  private SpecificSegmentsQuerySegmentWalker walker = null;
   private DruidSchema schema = null;
   private DruidSchema schema2 = null;
-  private SegmentManager segmentManager;
-  private Set<String> segmentDataSourceNames;
-  private Set<String> joinableDataSourceNames;
+  private CountDownLatch buildTableLatch = new CountDownLatch(1);
+  private CountDownLatch markDataSourceLatch = new CountDownLatch(1);
 
   @Before
   public void setUp() throws Exception
   {
-    segmentDataSourceNames = Sets.newConcurrentHashSet();
-    joinableDataSourceNames = Sets.newConcurrentHashSet();
-
     final File tmpDir = temporaryFolder.newFolder();
     final QueryableIndex index1 = IndexBuilder.create()
                                               .tmpDir(new File(tmpDir, "1"))
@@ -175,17 +109,6 @@ public class DruidSchemaTest extends CalciteTestBase
                                               )
                                               .rows(ROWS2)
                                               .buildMMappedIndex();
-
-    segmentManager = new SegmentManager(EasyMock.createMock(SegmentLoader.class))
-    {
-      @Override
-      public Set<String> getDataSourceNames()
-      {
-        getDatasourcesLatch.countDown();
-        return segmentDataSourceNames;
-      }
-    };
-
     walker = new SpecificSegmentsQuerySegmentWalker(conglomerate).add(
         DataSegment.builder()
                    .dataSource(CalciteTests.DATASOURCE1)
@@ -231,25 +154,6 @@ public class DruidSchemaTest extends CalciteTestBase
     serverView = new TestServerInventoryView(walker.getSegments(), realtimeSegments);
     druidServers = serverView.getDruidServers();
 
-    final JoinableFactory globalTableJoinable = new JoinableFactory()
-    {
-      @Override
-      public boolean isDirectlyJoinable(DataSource dataSource)
-      {
-        return dataSource instanceof GlobalTableDataSource &&
-               joinableDataSourceNames.contains(((GlobalTableDataSource) dataSource).getName());
-      }
-
-      @Override
-      public Optional<Joinable> build(
-          DataSource dataSource,
-          JoinConditionAnalysis condition
-      )
-      {
-        return Optional.empty();
-      }
-    };
-
     schema = new DruidSchema(
         CalciteTests.createMockQueryLifecycleFactory(walker, conglomerate),
         serverView,
@@ -266,15 +170,22 @@ public class DruidSchemaTest extends CalciteTestBase
         buildTableLatch.countDown();
         return table;
       }
+
+      @Override
+      void markDataSourceAsNeedRebuild(String datasource)
+      {
+        super.markDataSourceAsNeedRebuild(datasource);
+        markDataSourceLatch.countDown();
+      }
     };
 
     schema2 = new DruidSchema(
-            CalciteTests.createMockQueryLifecycleFactory(walker, conglomerate),
-            serverView,
-            segmentManager,
-            new MapJoinableFactory(ImmutableSet.of(globalTableJoinable), ImmutableMap.of(globalTableJoinable.getClass(), GlobalTableDataSource.class)),
-            PLANNER_CONFIG_DEFAULT,
-            new NoopEscalator()
+        CalciteTests.createMockQueryLifecycleFactory(walker, conglomerate),
+        serverView,
+        segmentManager,
+        new MapJoinableFactory(ImmutableSet.of(globalTableJoinable), ImmutableMap.of(globalTableJoinable.getClass(), GlobalTableDataSource.class)),
+        PLANNER_CONFIG_DEFAULT,
+        new NoopEscalator()
     )
     {
 
@@ -297,6 +208,13 @@ public class DruidSchemaTest extends CalciteTestBase
           return super.refreshSegments(segments);
         }
       }
+
+      @Override
+      void markDataSourceAsNeedRebuild(String datasource)
+      {
+        super.markDataSourceAsNeedRebuild(datasource);
+        markDataSourceLatch.countDown();
+      }
     };
 
     schema.start();
@@ -533,6 +451,422 @@ public class DruidSchemaTest extends CalciteTestBase
   }
 
   @Test
+  public void testSegmentAddedCallbackAddNewHistoricalSegment() throws InterruptedException
+  {
+    String datasource = "newSegmentAddTest";
+    CountDownLatch addSegmentLatch = new CountDownLatch(1);
+    DruidSchema schema = new DruidSchema(
+        CalciteTests.createMockQueryLifecycleFactory(walker, conglomerate),
+        serverView,
+        segmentManager,
+        new MapJoinableFactory(ImmutableSet.of(), ImmutableMap.of()),
+        PLANNER_CONFIG_DEFAULT,
+        new NoopEscalator()
+    )
+    {
+      @Override
+      void addSegment(final DruidServerMetadata server, final DataSegment segment)
+      {
+        super.addSegment(server, segment);
+        if (datasource.equals(segment.getDataSource())) {
+          addSegmentLatch.countDown();
+        }
+      }
+    };
+
+    serverView.addSegment(newSegment(datasource, 1), ServerType.HISTORICAL);
+    Assert.assertTrue(addSegmentLatch.await(1, TimeUnit.SECONDS));
+
+    Assert.assertEquals(5, schema.getTotalSegments());
+    List<AvailableSegmentMetadata> metadatas = schema
+        .getSegmentMetadataSnapshot()
+        .values()
+        .stream()
+        .filter(metadata -> datasource.equals(metadata.getSegment().getDataSource()))
+        .collect(Collectors.toList());
+    Assert.assertEquals(1, metadatas.size());
+    AvailableSegmentMetadata metadata = metadatas.get(0);
+    Assert.assertEquals(0, metadata.isRealtime());
+    Assert.assertEquals(0, metadata.getNumRows());
+    Assert.assertTrue(schema.getSegmentsNeedingRefresh().contains(metadata.getSegment().getId()));
+  }
+
+  @Test
+  public void testSegmentAddedCallbackAddExistingSegment() throws InterruptedException
+  {
+    String datasource = "newSegmentAddTest";
+    CountDownLatch addSegmentLatch = new CountDownLatch(2);
+    DruidSchema schema = new DruidSchema(
+        CalciteTests.createMockQueryLifecycleFactory(walker, conglomerate),
+        serverView,
+        segmentManager,
+        new MapJoinableFactory(ImmutableSet.of(), ImmutableMap.of()),
+        PLANNER_CONFIG_DEFAULT,
+        new NoopEscalator()
+    )
+    {
+      @Override
+      void addSegment(final DruidServerMetadata server, final DataSegment segment)
+      {
+        super.addSegment(server, segment);
+        if (datasource.equals(segment.getDataSource())) {
+          addSegmentLatch.countDown();
+        }
+      }
+    };
+
+    DataSegment segment = newSegment(datasource, 1);
+    serverView.addSegment(segment, ServerType.REALTIME);
+    serverView.addSegment(segment, ServerType.HISTORICAL);
+    Assert.assertTrue(addSegmentLatch.await(1, TimeUnit.SECONDS));
+
+    Assert.assertEquals(5, schema.getTotalSegments());
+    List<AvailableSegmentMetadata> metadatas = schema
+        .getSegmentMetadataSnapshot()
+        .values()
+        .stream()
+        .filter(metadata -> datasource.equals(metadata.getSegment().getDataSource()))
+        .collect(Collectors.toList());
+    Assert.assertEquals(1, metadatas.size());
+    AvailableSegmentMetadata metadata = metadatas.get(0);
+    Assert.assertEquals(0, metadata.isRealtime()); // the realtime flag is unset once any historical server serves the segment
+    Assert.assertEquals(0, metadata.getNumRows());
+    Assert.assertEquals(2, metadata.getNumReplicas());
+    Assert.assertTrue(schema.getSegmentsNeedingRefresh().contains(metadata.getSegment().getId()));
+    Assert.assertFalse(schema.getMutableSegments().contains(metadata.getSegment().getId()));
+  }
+
+  @Test
+  public void testSegmentAddedCallbackAddNewRealtimeSegment() throws InterruptedException
+  {
+    String datasource = "newSegmentAddTest";
+    CountDownLatch addSegmentLatch = new CountDownLatch(1);
+    DruidSchema schema = new DruidSchema(
+        CalciteTests.createMockQueryLifecycleFactory(walker, conglomerate),
+        serverView,
+        segmentManager,
+        new MapJoinableFactory(ImmutableSet.of(), ImmutableMap.of()),
+        PLANNER_CONFIG_DEFAULT,
+        new NoopEscalator()
+    )
+    {
+      @Override
+      void addSegment(final DruidServerMetadata server, final DataSegment segment)
+      {
+        super.addSegment(server, segment);
+        if (datasource.equals(segment.getDataSource())) {
+          addSegmentLatch.countDown();
+        }
+      }
+    };
+
+    serverView.addSegment(newSegment(datasource, 1), ServerType.REALTIME);
+    Assert.assertTrue(addSegmentLatch.await(1, TimeUnit.SECONDS));
+
+    Assert.assertEquals(5, schema.getTotalSegments());
+    List<AvailableSegmentMetadata> metadatas = schema
+        .getSegmentMetadataSnapshot()
+        .values()
+        .stream()
+        .filter(metadata -> datasource.equals(metadata.getSegment().getDataSource()))
+        .collect(Collectors.toList());
+    Assert.assertEquals(1, metadatas.size());
+    AvailableSegmentMetadata metadata = metadatas.get(0);
+    Assert.assertEquals(1, metadata.isRealtime());
+    Assert.assertEquals(0, metadata.getNumRows());
+    Assert.assertTrue(schema.getSegmentsNeedingRefresh().contains(metadata.getSegment().getId()));
+    Assert.assertTrue(schema.getMutableSegments().contains(metadata.getSegment().getId()));
+  }
+
+  @Test
+  public void testSegmentAddedCallbackAddNewBroadcastSegment() throws InterruptedException
+  {
+    String datasource = "newSegmentAddTest";
+    CountDownLatch addSegmentLatch = new CountDownLatch(1);
+    DruidSchema schema = new DruidSchema(
+        CalciteTests.createMockQueryLifecycleFactory(walker, conglomerate),
+        serverView,
+        segmentManager,
+        new MapJoinableFactory(ImmutableSet.of(), ImmutableMap.of()),
+        PLANNER_CONFIG_DEFAULT,
+        new NoopEscalator()
+    )
+    {
+      @Override
+      void addSegment(final DruidServerMetadata server, final DataSegment segment)
+      {
+        super.addSegment(server, segment);
+        if (datasource.equals(segment.getDataSource())) {
+          addSegmentLatch.countDown();
+        }
+      }
+    };
+
+    serverView.addSegment(newSegment(datasource, 1), ServerType.BROKER);
+    Assert.assertTrue(addSegmentLatch.await(1, TimeUnit.SECONDS));
+
+    Assert.assertEquals(4, schema.getTotalSegments());
+    List<AvailableSegmentMetadata> metadatas = schema
+        .getSegmentMetadataSnapshot()
+        .values()
+        .stream()
+        .filter(metadata -> datasource.equals(metadata.getSegment().getDataSource()))
+        .collect(Collectors.toList());
+    Assert.assertEquals(0, metadatas.size());
+    Assert.assertTrue(schema.getDataSourcesNeedingRebuild().contains(datasource));
+  }
+
+  @Test
+  public void testSegmentRemovedCallbackEmptyDataSourceAfterRemove() throws InterruptedException, IOException
+  {
+    String datasource = "segmentRemoveTest";
+    CountDownLatch addSegmentLatch = new CountDownLatch(1);
+    CountDownLatch removeSegmentLatch = new CountDownLatch(1);
+    DruidSchema schema = new DruidSchema(
+        CalciteTests.createMockQueryLifecycleFactory(walker, conglomerate),
+        serverView,
+        segmentManager,
+        new MapJoinableFactory(ImmutableSet.of(), ImmutableMap.of()),
+        PLANNER_CONFIG_DEFAULT,
+        new NoopEscalator()
+    )
+    {
+      @Override
+      void addSegment(final DruidServerMetadata server, final DataSegment segment)
+      {
+        super.addSegment(server, segment);
+        if (datasource.equals(segment.getDataSource())) {
+          addSegmentLatch.countDown();
+        }
+      }
+
+      @Override
+      void removeSegment(final DataSegment segment)
+      {
+        super.removeSegment(segment);
+        if (datasource.equals(segment.getDataSource())) {
+          removeSegmentLatch.countDown();
+        }
+      }
+    };
+
+    DataSegment segment = newSegment(datasource, 1);
+    serverView.addSegment(segment, ServerType.REALTIME);
+    Assert.assertTrue(addSegmentLatch.await(1, TimeUnit.SECONDS));
+    schema.refresh(Sets.newHashSet(segment.getId()), Sets.newHashSet(datasource));
+
+    serverView.removeSegment(segment, ServerType.REALTIME);
+    Assert.assertTrue(removeSegmentLatch.await(1, TimeUnit.SECONDS));
+
+    Assert.assertEquals(4, schema.getTotalSegments());
+    List<AvailableSegmentMetadata> metadatas = schema
+        .getSegmentMetadataSnapshot()
+        .values()
+        .stream()
+        .filter(metadata -> datasource.equals(metadata.getSegment().getDataSource()))
+        .collect(Collectors.toList());
+    Assert.assertEquals(0, metadatas.size());
+    Assert.assertFalse(schema.getSegmentsNeedingRefresh().contains(segment.getId()));
+    Assert.assertFalse(schema.getMutableSegments().contains(segment.getId()));
+    Assert.assertFalse(schema.getDataSourcesNeedingRebuild().contains(datasource));
+    Assert.assertFalse(schema.getTableNames().contains(datasource));
+  }
+
+  @Test
+  public void testSegmentRemovedCallbackNonEmptyDataSourceAfterRemove() throws InterruptedException, IOException
+  {
+    String datasource = "segmentRemoveTest";
+    CountDownLatch addSegmentLatch = new CountDownLatch(2);
+    CountDownLatch removeSegmentLatch = new CountDownLatch(1);
+    DruidSchema schema = new DruidSchema(
+        CalciteTests.createMockQueryLifecycleFactory(walker, conglomerate),
+        serverView,
+        segmentManager,
+        new MapJoinableFactory(ImmutableSet.of(), ImmutableMap.of()),
+        PLANNER_CONFIG_DEFAULT,
+        new NoopEscalator()
+    )
+    {
+      @Override
+      void addSegment(final DruidServerMetadata server, final DataSegment segment)
+      {
+        super.addSegment(server, segment);
+        if (datasource.equals(segment.getDataSource())) {
+          addSegmentLatch.countDown();
+        }
+      }
+
+      @Override
+      void removeSegment(final DataSegment segment)
+      {
+        super.removeSegment(segment);
+        if (datasource.equals(segment.getDataSource())) {
+          removeSegmentLatch.countDown();
+        }
+      }
+    };
+
+    List<DataSegment> segments = ImmutableList.of(
+        newSegment(datasource, 1),
+        newSegment(datasource, 2)
+    );
+    serverView.addSegment(segments.get(0), ServerType.REALTIME);
+    serverView.addSegment(segments.get(1), ServerType.HISTORICAL);
+    Assert.assertTrue(addSegmentLatch.await(1, TimeUnit.SECONDS));
+    schema.refresh(segments.stream().map(DataSegment::getId).collect(Collectors.toSet()), Sets.newHashSet(datasource));
+
+    serverView.removeSegment(segments.get(0), ServerType.REALTIME);
+    Assert.assertTrue(removeSegmentLatch.await(1, TimeUnit.SECONDS));
+
+    Assert.assertEquals(5, schema.getTotalSegments());
+    List<AvailableSegmentMetadata> metadatas = schema
+        .getSegmentMetadataSnapshot()
+        .values()
+        .stream()
+        .filter(metadata -> datasource.equals(metadata.getSegment().getDataSource()))
+        .collect(Collectors.toList());
+    Assert.assertEquals(1, metadatas.size());
+    Assert.assertFalse(schema.getSegmentsNeedingRefresh().contains(segments.get(0).getId()));
+    Assert.assertFalse(schema.getMutableSegments().contains(segments.get(0).getId()));
+    Assert.assertTrue(schema.getDataSourcesNeedingRebuild().contains(datasource));
+    Assert.assertTrue(schema.getTableNames().contains(datasource));
+  }
+
+  @Test
+  public void testServerSegmentRemovedCallbackRemoveUnknownSegment() throws InterruptedException
+  {
+    String datasource = "serverSegmentRemoveTest";
+    CountDownLatch removeServerSegmentLatch = new CountDownLatch(1);
+    DruidSchema schema = new DruidSchema(
+        CalciteTests.createMockQueryLifecycleFactory(walker, conglomerate),
+        serverView,
+        segmentManager,
+        new MapJoinableFactory(ImmutableSet.of(), ImmutableMap.of()),
+        PLANNER_CONFIG_DEFAULT,
+        new NoopEscalator()
+    )
+    {
+      @Override
+      void removeServerSegment(final DruidServerMetadata server, final DataSegment segment)
+      {
+        super.removeServerSegment(server, segment);
+        if (datasource.equals(segment.getDataSource())) {
+          removeServerSegmentLatch.countDown();
+        }
+      }
+    };
+
+    serverView.addSegment(newSegment(datasource, 1), ServerType.BROKER);
+
+    serverView.removeSegment(newSegment(datasource, 1), ServerType.HISTORICAL);
+    Assert.assertTrue(removeServerSegmentLatch.await(1, TimeUnit.SECONDS));
+
+    Assert.assertEquals(4, schema.getTotalSegments());
+  }
+
+  @Test
+  public void testServerSegmentRemovedCallbackRemoveBrokerSegment() throws InterruptedException
+  {
+    String datasource = "serverSegmentRemoveTest";
+    CountDownLatch addSegmentLatch = new CountDownLatch(1);
+    CountDownLatch removeServerSegmentLatch = new CountDownLatch(1);
+    DruidSchema schema = new DruidSchema(
+        CalciteTests.createMockQueryLifecycleFactory(walker, conglomerate),
+        serverView,
+        segmentManager,
+        new MapJoinableFactory(ImmutableSet.of(), ImmutableMap.of()),
+        PLANNER_CONFIG_DEFAULT,
+        new NoopEscalator()
+    )
+    {
+      @Override
+      void addSegment(final DruidServerMetadata server, final DataSegment segment)
+      {
+        super.addSegment(server, segment);
+        if (datasource.equals(segment.getDataSource())) {
+          addSegmentLatch.countDown();
+        }
+      }
+
+      @Override
+      void removeServerSegment(final DruidServerMetadata server, final DataSegment segment)
+      {
+        super.removeServerSegment(server, segment);
+        if (datasource.equals(segment.getDataSource())) {
+          removeServerSegmentLatch.countDown();
+        }
+      }
+    };
+
+    DataSegment segment = newSegment(datasource, 1);
+    serverView.addSegment(segment, ServerType.HISTORICAL);
+    serverView.addSegment(segment, ServerType.BROKER);
+    Assert.assertTrue(addSegmentLatch.await(1, TimeUnit.SECONDS));
+
+    serverView.removeSegment(segment, ServerType.BROKER);
+    Assert.assertTrue(removeServerSegmentLatch.await(1, TimeUnit.SECONDS));
+
+    Assert.assertEquals(5, schema.getTotalSegments());
+    Assert.assertTrue(schema.getDataSourcesNeedingRebuild().contains(datasource));
+  }
+
+  @Test
+  public void testServerSegmentRemovedCallbackRemoveHistoricalSegment() throws InterruptedException
+  {
+    String datasource = "serverSegmentRemoveTest";
+    CountDownLatch addSegmentLatch = new CountDownLatch(1);
+    CountDownLatch removeServerSegmentLatch = new CountDownLatch(1);
+    DruidSchema schema = new DruidSchema(
+        CalciteTests.createMockQueryLifecycleFactory(walker, conglomerate),
+        serverView,
+        segmentManager,
+        new MapJoinableFactory(ImmutableSet.of(), ImmutableMap.of()),
+        PLANNER_CONFIG_DEFAULT,
+        new NoopEscalator()
+    )
+    {
+      @Override
+      void addSegment(final DruidServerMetadata server, final DataSegment segment)
+      {
+        super.addSegment(server, segment);
+        if (datasource.equals(segment.getDataSource())) {
+          addSegmentLatch.countDown();
+        }
+      }
+
+      @Override
+      void removeServerSegment(final DruidServerMetadata server, final DataSegment segment)
+      {
+        super.removeServerSegment(server, segment);
+        if (datasource.equals(segment.getDataSource())) {
+          removeServerSegmentLatch.countDown();
+        }
+      }
+    };
+
+    DataSegment segment = newSegment(datasource, 1);
+    serverView.addSegment(segment, ServerType.HISTORICAL);
+    serverView.addSegment(segment, ServerType.BROKER);
+    Assert.assertTrue(addSegmentLatch.await(1, TimeUnit.SECONDS));
+
+    serverView.removeSegment(segment, ServerType.HISTORICAL);
+    Assert.assertTrue(removeServerSegmentLatch.await(1, TimeUnit.SECONDS));
+
+    Assert.assertEquals(5, schema.getTotalSegments());
+    List<AvailableSegmentMetadata> metadatas = schema
+        .getSegmentMetadataSnapshot()
+        .values()
+        .stream()
+        .filter(metadata -> datasource.equals(metadata.getSegment().getDataSource()))
+        .collect(Collectors.toList());
+    Assert.assertEquals(1, metadatas.size());
+    AvailableSegmentMetadata metadata = metadatas.get(0);
+    Assert.assertEquals(0, metadata.isRealtime());
+    Assert.assertEquals(0, metadata.getNumRows());
+    Assert.assertEquals(0, metadata.getNumReplicas()); // brokers are not counted as replicas yet
+  }
+
+  @Test
   public void testLocalSegmentCacheSetsDataSourceAsGlobalAndJoinable() throws InterruptedException
   {
     DruidTable fooTable = (DruidTable) schema.getTableMap().get("foo");
@@ -542,8 +876,9 @@ public class DruidSchemaTest extends CalciteTestBase
     Assert.assertFalse(fooTable.isJoinable());
     Assert.assertFalse(fooTable.isBroadcast());
 
-    buildTableLatch.await(1, TimeUnit.SECONDS);
+    Assert.assertTrue(buildTableLatch.await(1, TimeUnit.SECONDS));
 
+    buildTableLatch = new CountDownLatch(1);
     final DataSegment someNewBrokerSegment = new DataSegment(
         "foo",
         Intervals.of("2012/2013"),
@@ -560,14 +895,11 @@ public class DruidSchemaTest extends CalciteTestBase
     segmentDataSourceNames.add("foo");
     joinableDataSourceNames.add("foo");
     serverView.addSegment(someNewBrokerSegment, ServerType.BROKER);
-
+    Assert.assertTrue(markDataSourceLatch.await(2, TimeUnit.SECONDS));
     // wait for build twice
-    buildTableLatch = new CountDownLatch(2);
-    buildTableLatch.await(1, TimeUnit.SECONDS);
-
+    Assert.assertTrue(buildTableLatch.await(2, TimeUnit.SECONDS));
     // wait for get again, just to make sure table has been updated (latch counts down just before tables are updated)
-    getDatasourcesLatch = new CountDownLatch(1);
-    getDatasourcesLatch.await(1, TimeUnit.SECONDS);
+    Assert.assertTrue(getDatasourcesLatch.await(2, TimeUnit.SECONDS));
 
     fooTable = (DruidTable) schema.getTableMap().get("foo");
     Assert.assertNotNull(fooTable);
@@ -577,18 +909,18 @@ public class DruidSchemaTest extends CalciteTestBase
     Assert.assertTrue(fooTable.isBroadcast());
 
     // now remove it
+    markDataSourceLatch = new CountDownLatch(1);
+    buildTableLatch = new CountDownLatch(1);
+    getDatasourcesLatch = new CountDownLatch(1);
     joinableDataSourceNames.remove("foo");
     segmentDataSourceNames.remove("foo");
     serverView.removeSegment(someNewBrokerSegment, ServerType.BROKER);
 
+    Assert.assertTrue(markDataSourceLatch.await(2, TimeUnit.SECONDS));
     // wait for build
-    buildTableLatch.await(1, TimeUnit.SECONDS);
-    buildTableLatch = new CountDownLatch(1);
-    buildTableLatch.await(1, TimeUnit.SECONDS);
-
+    Assert.assertTrue(buildTableLatch.await(2, TimeUnit.SECONDS));
     // wait for get again, just to make sure table has been updated (latch counts down just before tables are updated)
-    getDatasourcesLatch = new CountDownLatch(1);
-    getDatasourcesLatch.await(1, TimeUnit.SECONDS);
+    Assert.assertTrue(getDatasourcesLatch.await(2, TimeUnit.SECONDS));
 
     fooTable = (DruidTable) schema.getTableMap().get("foo");
     Assert.assertNotNull(fooTable);
@@ -609,8 +941,9 @@ public class DruidSchemaTest extends CalciteTestBase
     Assert.assertFalse(fooTable.isBroadcast());
 
     // wait for build twice
-    buildTableLatch.await(1, TimeUnit.SECONDS);
+    Assert.assertTrue(buildTableLatch.await(1, TimeUnit.SECONDS));
 
+    buildTableLatch = new CountDownLatch(1);
     final DataSegment someNewBrokerSegment = new DataSegment(
         "foo",
         Intervals.of("2012/2013"),
@@ -627,12 +960,10 @@ public class DruidSchemaTest extends CalciteTestBase
     segmentDataSourceNames.add("foo");
     serverView.addSegment(someNewBrokerSegment, ServerType.BROKER);
 
-    buildTableLatch = new CountDownLatch(2);
-    buildTableLatch.await(1, TimeUnit.SECONDS);
-
+    Assert.assertTrue(markDataSourceLatch.await(2, TimeUnit.SECONDS));
+    Assert.assertTrue(buildTableLatch.await(2, TimeUnit.SECONDS));
     // wait for get again, just to make sure table has been updated (latch counts down just before tables are updated)
-    getDatasourcesLatch = new CountDownLatch(1);
-    getDatasourcesLatch.await(1, TimeUnit.SECONDS);
+    Assert.assertTrue(getDatasourcesLatch.await(2, TimeUnit.SECONDS));
 
     fooTable = (DruidTable) schema.getTableMap().get("foo");
     Assert.assertNotNull(fooTable);
@@ -643,19 +974,18 @@ public class DruidSchemaTest extends CalciteTestBase
     Assert.assertTrue(fooTable.isBroadcast());
     Assert.assertFalse(fooTable.isJoinable());
 
-
     // now remove it
+    markDataSourceLatch = new CountDownLatch(1);
+    buildTableLatch = new CountDownLatch(1);
+    getDatasourcesLatch = new CountDownLatch(1);
     segmentDataSourceNames.remove("foo");
     serverView.removeSegment(someNewBrokerSegment, ServerType.BROKER);
 
+    Assert.assertTrue(markDataSourceLatch.await(2, TimeUnit.SECONDS));
     // wait for build
-    buildTableLatch.await(1, TimeUnit.SECONDS);
-    buildTableLatch = new CountDownLatch(1);
-    buildTableLatch.await(1, TimeUnit.SECONDS);
-
+    Assert.assertTrue(buildTableLatch.await(2, TimeUnit.SECONDS));
     // wait for get again, just to make sure table has been updated (latch counts down just before tables are updated)
-    getDatasourcesLatch = new CountDownLatch(1);
-    getDatasourcesLatch.await(1, TimeUnit.SECONDS);
+    Assert.assertTrue(getDatasourcesLatch.await(2, TimeUnit.SECONDS));
 
     fooTable = (DruidTable) schema.getTableMap().get("foo");
     Assert.assertNotNull(fooTable);
@@ -664,4 +994,21 @@ public class DruidSchemaTest extends CalciteTestBase
     Assert.assertFalse(fooTable.isBroadcast());
     Assert.assertFalse(fooTable.isJoinable());
   }
+
+  private static DataSegment newSegment(String datasource, int partitionId)
+  {
+    return new DataSegment(
+        datasource,
+        Intervals.of("2012/2013"),
+        "version1",
+        null,
+        ImmutableList.of("dim1", "dim2"),
+        ImmutableList.of("met1", "met2"),
+        new NumberedShardSpec(partitionId, 0),
+        null,
+        1,
+        100L,
+        PruneSpecsHolder.DEFAULT
+    );
+  }
 }
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/schema/DruidSchemaTestCommon.java b/sql/src/test/java/org/apache/druid/sql/calcite/schema/DruidSchemaTestCommon.java
new file mode 100644
index 0000000..2511d97
--- /dev/null
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/schema/DruidSchemaTestCommon.java
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.sql.calcite.schema;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Sets;
+import org.apache.druid.data.input.InputRow;
+import org.apache.druid.java.util.common.io.Closer;
+import org.apache.druid.query.DataSource;
+import org.apache.druid.query.GlobalTableDataSource;
+import org.apache.druid.query.QueryRunnerFactoryConglomerate;
+import org.apache.druid.segment.join.JoinConditionAnalysis;
+import org.apache.druid.segment.join.Joinable;
+import org.apache.druid.segment.join.JoinableFactory;
+import org.apache.druid.segment.loading.SegmentLoader;
+import org.apache.druid.server.QueryStackTests;
+import org.apache.druid.server.SegmentManager;
+import org.apache.druid.sql.calcite.planner.PlannerConfig;
+import org.apache.druid.sql.calcite.util.CalciteTestBase;
+import org.apache.druid.sql.calcite.util.CalciteTests;
+import org.easymock.EasyMock;
+import org.joda.time.Period;
+import org.junit.AfterClass;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Rule;
+import org.junit.rules.TemporaryFolder;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Optional;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+
+public abstract class DruidSchemaTestCommon extends CalciteTestBase
+{
+  static final PlannerConfig PLANNER_CONFIG_DEFAULT = new PlannerConfig()
+  {
+    @Override
+    public Period getMetadataRefreshPeriod()
+    {
+      return new Period("PT1S");
+    }
+  };
+
+  static final List<InputRow> ROWS1 = ImmutableList.of(
+      CalciteTests.createRow(ImmutableMap.of("t", "2000-01-01", "m1", "1.0", "dim1", "")),
+      CalciteTests.createRow(ImmutableMap.of("t", "2000-01-02", "m1", "2.0", "dim1", "10.1")),
+      CalciteTests.createRow(ImmutableMap.of("t", "2000-01-03", "m1", "3.0", "dim1", "2"))
+  );
+
+  static final List<InputRow> ROWS2 = ImmutableList.of(
+      CalciteTests.createRow(ImmutableMap.of("t", "2001-01-01", "m1", "4.0", "dim2", ImmutableList.of("a"))),
+      CalciteTests.createRow(ImmutableMap.of("t", "2001-01-02", "m1", "5.0", "dim2", ImmutableList.of("abc"))),
+      CalciteTests.createRow(ImmutableMap.of("t", "2001-01-03", "m1", "6.0"))
+  );
+
+  static QueryRunnerFactoryConglomerate conglomerate;
+  static Closer resourceCloser;
+
+  CountDownLatch getDatasourcesLatch = new CountDownLatch(1);
+
+  @BeforeClass
+  public static void setUpClass()
+  {
+    resourceCloser = Closer.create();
+    conglomerate = QueryStackTests.createQueryRunnerFactoryConglomerate(resourceCloser);
+  }
+
+  @AfterClass
+  public static void tearDownClass() throws IOException
+  {
+    resourceCloser.close();
+  }
+
+  @Rule
+  public TemporaryFolder temporaryFolder = new TemporaryFolder();
+
+  SegmentManager segmentManager;
+  Set<String> segmentDataSourceNames;
+  Set<String> joinableDataSourceNames;
+  JoinableFactory globalTableJoinable;
+
+  @Before
+  public void setUpCommon()
+  {
+    segmentDataSourceNames = Sets.newConcurrentHashSet();
+    joinableDataSourceNames = Sets.newConcurrentHashSet();
+
+    segmentManager = new SegmentManager(EasyMock.createMock(SegmentLoader.class))
+    {
+      @Override
+      public Set<String> getDataSourceNames()
+      {
+        getDatasourcesLatch.countDown();
+        return segmentDataSourceNames;
+      }
+    };
+
+    globalTableJoinable = new JoinableFactory()
+    {
+      @Override
+      public boolean isDirectlyJoinable(DataSource dataSource)
+      {
+        return dataSource instanceof GlobalTableDataSource &&
+               joinableDataSourceNames.contains(((GlobalTableDataSource) dataSource).getName());
+      }
+
+      @Override
+      public Optional<Joinable> build(
+          DataSource dataSource,
+          JoinConditionAnalysis condition
+      )
+      {
+        return Optional.empty();
+      }
+    };
+  }
+}
