You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@druid.apache.org by kf...@apache.org on 2023/05/30 03:23:02 UTC

[druid] branch master updated: Add tests for CostBalancerStrategy (#14230)

This is an automated email from the ASF dual-hosted git repository.

kfaraz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git


The following commit(s) were added to refs/heads/master by this push:
     new d4cacebf79 Add tests for CostBalancerStrategy (#14230)
d4cacebf79 is described below

commit d4cacebf79c9795aa36b5e0776df44ed735b93d0
Author: Kashif Faraz <ka...@gmail.com>
AuthorDate: Tue May 30 08:52:56 2023 +0530

    Add tests for CostBalancerStrategy (#14230)
    
    Changes:
    - `CostBalancerStrategyTest`
      - Focus on verification of cost computations rather than choosing servers in this test
      - Add new tests `testComputeCost` and `testJointSegmentsCost`
      - Add tests to demonstrate that with a long enough interval gap, all costs become negligible
      - Retain `testIntervalCost` and `testIntervalCostAdditivity`
      - Remove redundant tests such as `testCostBalancerMultiThreadedStrategy` and `testCostBalancerSingleThreadStrategy`, as
    verification of this behaviour is better suited to `BalancingStrategiesTest`.
    - `CostBalancerStrategyBenchmark`
      - Remove usage of static methods from `CostBalancerStrategyTest`
      - Explicitly set up the cluster and segments to use for benchmarking
---
 .../coordinator/CostBalancerStrategyBenchmark.java |  52 ++-
 .../coordinator/CostBalancerStrategyTest.java      | 438 +++++++++++++--------
 .../server/coordinator/CreateDataSegments.java     |  24 +-
 3 files changed, 335 insertions(+), 179 deletions(-)

diff --git a/server/src/test/java/org/apache/druid/server/coordinator/CostBalancerStrategyBenchmark.java b/server/src/test/java/org/apache/druid/server/coordinator/CostBalancerStrategyBenchmark.java
index 881698f7a2..e5ce98f7af 100644
--- a/server/src/test/java/org/apache/druid/server/coordinator/CostBalancerStrategyBenchmark.java
+++ b/server/src/test/java/org/apache/druid/server/coordinator/CostBalancerStrategyBenchmark.java
@@ -22,19 +22,23 @@ package org.apache.druid.server.coordinator;
 import com.carrotsearch.junitbenchmarks.AbstractBenchmark;
 import com.carrotsearch.junitbenchmarks.BenchmarkOptions;
 import com.google.common.util.concurrent.MoreExecutors;
+import org.apache.druid.client.DruidServer;
 import org.apache.druid.java.util.common.Intervals;
 import org.apache.druid.java.util.common.concurrent.Execs;
+import org.apache.druid.java.util.common.granularity.Granularities;
+import org.apache.druid.server.coordination.ServerType;
 import org.apache.druid.timeline.DataSegment;
 import org.joda.time.Interval;
-import org.junit.AfterClass;
-import org.junit.BeforeClass;
 import org.junit.Ignore;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
+import java.util.Random;
+import java.util.stream.Collectors;
 
 @Ignore
 @RunWith(Parameterized.class)
@@ -56,36 +60,54 @@ public class CostBalancerStrategyBenchmark extends AbstractBenchmark
   }
 
   private final CostBalancerStrategy strategy;
+  private final List<ServerHolder> serverHolderList;
 
   public CostBalancerStrategyBenchmark(CostBalancerStrategy costBalancerStrategy)
   {
     this.strategy = costBalancerStrategy;
+    this.serverHolderList = initServers();
   }
 
-  private static List<ServerHolder> serverHolderList;
-  volatile ServerHolder selected;
-
-  @BeforeClass
-  public static void setup()
+  private List<ServerHolder> initServers()
   {
-    serverHolderList = CostBalancerStrategyTest.setupDummyCluster(5, 20000);
-  }
+    final List<DruidServer> servers = new ArrayList<>();
+    for (int i = 0; i < 6; ++i) {
+      DruidServer druidServer = new DruidServer(
+          "server_" + i,
+          "localhost", null, 10_000_000L, ServerType.HISTORICAL, "hot", 1
+      );
+      servers.add(druidServer);
+    }
 
-  @AfterClass
-  public static void tearDown()
-  {
-    serverHolderList = null;
+    // Create and randomly distribute some segments amongst the servers
+    final List<DataSegment> segments =
+        CreateDataSegments.ofDatasource("wikipedia")
+                          .forIntervals(200, Granularities.DAY)
+                          .withNumPartitions(100)
+                          .eachOfSizeInMb(200);
+    final Random random = new Random(100);
+    segments.forEach(
+        segment -> servers.get(random.nextInt(servers.size()))
+                          .addDataSegment(segment)
+    );
+
+    return servers.stream()
+                  .map(DruidServer::toImmutableDruidServer)
+                  .map(server -> new ServerHolder(server, null))
+                  .collect(Collectors.toList());
   }
 
+  volatile ServerHolder selected;
+
   @Test
   @BenchmarkOptions(warmupRounds = 10, benchmarkRounds = 1000)
   public void testBenchmark()
   {
-    DataSegment segment = CostBalancerStrategyTest.getSegment(1000, "testds", interval1);
+    DataSegment segment = DataSegment.builder().dataSource("testds").version("1000")
+                                     .interval(interval1).size(100L).build();
     selected = strategy.findNewSegmentHomeReplicator(segment, serverHolderList);
   }
 
-
   // Benchmark Joda Interval Gap impl vs CostBalancer.gapMillis
   private final Interval interval1 = Intervals.of("2015-01-01T01:00:00Z/2015-01-01T02:00:00Z");
   private final Interval interval2 = Intervals.of("2015-02-01T01:00:00Z/2015-02-01T02:00:00Z");
diff --git a/server/src/test/java/org/apache/druid/server/coordinator/CostBalancerStrategyTest.java b/server/src/test/java/org/apache/druid/server/coordinator/CostBalancerStrategyTest.java
index 2e2d2a44f9..1b6a0cfdc6 100644
--- a/server/src/test/java/org/apache/druid/server/coordinator/CostBalancerStrategyTest.java
+++ b/server/src/test/java/org/apache/druid/server/coordinator/CostBalancerStrategyTest.java
@@ -19,211 +19,329 @@
 
 package org.apache.druid.server.coordinator;
 
-import com.google.common.collect.ImmutableMap;
 import com.google.common.util.concurrent.MoreExecutors;
-import org.apache.druid.client.ImmutableDruidDataSource;
-import org.apache.druid.client.ImmutableDruidServer;
-import org.apache.druid.client.ImmutableDruidServerTests;
-import org.apache.druid.java.util.common.DateTimes;
-import org.apache.druid.java.util.common.Intervals;
-import org.apache.druid.java.util.common.concurrent.Execs;
-import org.apache.druid.server.coordination.DruidServerMetadata;
+import org.apache.druid.client.DruidServer;
+import org.apache.druid.java.util.common.granularity.Granularities;
+import org.apache.druid.java.util.common.granularity.GranularityType;
 import org.apache.druid.server.coordination.ServerType;
+import org.apache.druid.server.coordinator.simulate.BlockingExecutorService;
 import org.apache.druid.timeline.DataSegment;
-import org.apache.druid.timeline.SegmentId;
-import org.easymock.EasyMock;
-import org.joda.time.DateTime;
-import org.joda.time.Interval;
+import org.junit.After;
 import org.junit.Assert;
+import org.junit.Before;
 import org.junit.Test;
 
 import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashMap;
 import java.util.List;
-import java.util.Map;
-import java.util.concurrent.ConcurrentHashMap;
+import java.util.Random;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.TimeUnit;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
 public class CostBalancerStrategyTest
 {
-  private static final Interval DAY = Intervals.of("2015-01-01T00/2015-01-01T01");
+  private static final double DELTA = 1e-6;
+  private static final String DS_WIKI = "wiki";
 
-  /**
-   * Create Druid cluster with serverCount servers having maxSegments segments each, and 1 server with 98 segment
-   * Cost Balancer Strategy should assign the next segment to the server with less segments.
-   */
-  public static List<ServerHolder> setupDummyCluster(int serverCount, int maxSegments)
+  private ExecutorService balancerExecutor;
+  private CostBalancerStrategy strategy;
+  private int uniqueServerId;
+
+  @Before
+  public void setup()
   {
-    List<ServerHolder> serverHolderList = new ArrayList<>();
-    // Create 10 servers with current size being 3K & max size being 10K
-    // Each having having 100 segments
-    for (int i = 0; i < serverCount; i++) {
-      LoadQueuePeonTester fromPeon = new LoadQueuePeonTester();
-
-      List<DataSegment> segments = IntStream
-          .range(0, maxSegments)
-          .mapToObj(j -> getSegment(j))
-          .collect(Collectors.toList());
-      ImmutableDruidDataSource dataSource = new ImmutableDruidDataSource("DUMMY", Collections.emptyMap(), segments);
-
-      String serverName = "DruidServer_Name_" + i;
-      ServerHolder serverHolder = new ServerHolder(
-          new ImmutableDruidServer(
-              new DruidServerMetadata(serverName, "localhost", null, 10000000L, ServerType.HISTORICAL, "hot", 1),
-              3000L,
-              ImmutableMap.of("DUMMY", dataSource),
-              segments.size()
-          ),
-          fromPeon
-      );
-      serverHolderList.add(serverHolder);
-    }
+    balancerExecutor = new BlockingExecutorService("test-balance-exec-%d");
+    strategy = new CostBalancerStrategy(MoreExecutors.listeningDecorator(balancerExecutor));
+  }
 
-    // The best server to be available for next segment assignment has only 98 Segments
-    LoadQueuePeonTester fromPeon = new LoadQueuePeonTester();
-    ImmutableDruidServer druidServer = EasyMock.createMock(ImmutableDruidServer.class);
-    EasyMock.expect(druidServer.getName()).andReturn("BEST_SERVER").anyTimes();
-    EasyMock.expect(druidServer.getCurrSize()).andReturn(3000L).anyTimes();
-    EasyMock.expect(druidServer.getMaxSize()).andReturn(10000000L).anyTimes();
-
-    EasyMock.expect(druidServer.getSegment(EasyMock.anyObject())).andReturn(null).anyTimes();
-    Map<SegmentId, DataSegment> segments = new HashMap<>();
-    for (int j = 0; j < (maxSegments - 2); j++) {
-      DataSegment segment = getSegment(j);
-      segments.put(segment.getId(), segment);
-      EasyMock.expect(druidServer.getSegment(segment.getId())).andReturn(segment).anyTimes();
+  @After
+  public void tearDown()
+  {
+    if (balancerExecutor != null) {
+      balancerExecutor.shutdownNow();
     }
-    ImmutableDruidServerTests.expectSegments(druidServer, segments.values());
+  }
 
-    EasyMock.replay(druidServer);
-    serverHolderList.add(new ServerHolder(druidServer, fromPeon));
-    return serverHolderList;
+  @Test
+  public void testIntervalCostAdditivity()
+  {
+    Assert.assertEquals(
+        intervalCost(1, 1, 3),
+        intervalCost(1, 1, 2) + intervalCost(1, 2, 3),
+        DELTA
+    );
+
+    Assert.assertEquals(
+        intervalCost(2, 1, 3),
+        intervalCost(2, 1, 2) + intervalCost(2, 2, 3),
+        DELTA
+    );
+
+    Assert.assertEquals(
+        intervalCost(3, 1, 2),
+        intervalCost(1, 0, 1) + intervalCost(1, 1, 2) + intervalCost(1, 1, 2),
+        DELTA
+    );
   }
 
-  /**
-   * Returns segment with dummy id and size 100
-   *
-   * @param index
-   *
-   * @return segment
-   */
-  public static DataSegment getSegment(int index)
+  private double intervalCost(double x1, double y0, double y1)
   {
-    return getSegment(index, "DUMMY", DAY);
+    return CostBalancerStrategy.intervalCost(x1, y0, y1);
   }
 
-  public static DataSegment getSegment(int index, String dataSource, Interval interval)
+  @Test
+  public void testIntervalCost()
   {
-    // Not using EasyMock as it hampers the performance of multithreads.
-    DataSegment segment = new DataSegment(
-        dataSource,
-        interval,
-        String.valueOf(index),
-        new ConcurrentHashMap<>(),
-        new ArrayList<>(),
-        new ArrayList<>(),
-        null,
-        0,
-        index * 100L
+    // no overlap
+    // [0, 1) [1, 2)
+    Assert.assertEquals(0.3995764, intervalCost(1, 1, 2), DELTA);
+    // [0, 1) [-1, 0)
+    Assert.assertEquals(0.3995764, intervalCost(1, -1, 0), DELTA);
+
+    // exact overlap
+    // [0, 1), [0, 1)
+    Assert.assertEquals(0.7357589, intervalCost(1, 0, 1), DELTA);
+    // [0, 2), [0, 2)
+    Assert.assertEquals(2.270671, intervalCost(2, 0, 2), DELTA);
+
+    // partial overlap
+    // [0, 2), [1, 3)
+    Assert.assertEquals(1.681908, intervalCost(2, 1, 3), DELTA);
+    // [0, 2), [1, 2)
+    Assert.assertEquals(1.135335, intervalCost(2, 1, 2), DELTA);
+    // [0, 2), [0, 1)
+    Assert.assertEquals(1.135335, intervalCost(2, 0, 1), DELTA);
+    // [0, 3), [1, 2)
+    Assert.assertEquals(1.534912, intervalCost(3, 1, 2), DELTA);
+  }
+
+  @Test
+  public void testJointSegmentsCost()
+  {
+    final long noGap = 0;
+    final long oneDayGap = TimeUnit.DAYS.toMillis(1);
+    verifyJointSegmentsCost(GranularityType.HOUR, GranularityType.HOUR, noGap, 1.980884);
+    verifyJointSegmentsCost(GranularityType.HOUR, GranularityType.HOUR, oneDayGap, 1.000070);
+
+    verifyJointSegmentsCost(GranularityType.HOUR, GranularityType.DAY, noGap, 35.110275);
+    verifyJointSegmentsCost(GranularityType.DAY, GranularityType.DAY, noGap, 926.232308);
+    verifyJointSegmentsCost(GranularityType.DAY, GranularityType.DAY, oneDayGap, 599.434267);
+    verifyJointSegmentsCost(GranularityType.DAY, GranularityType.DAY, 7 * oneDayGap, 9.366160);
+
+    verifyJointSegmentsCost(GranularityType.DAY, GranularityType.MONTH, noGap, 2125.100840);
+    verifyJointSegmentsCost(GranularityType.MONTH, GranularityType.MONTH, noGap, 98247.576470);
+    verifyJointSegmentsCost(GranularityType.MONTH, GranularityType.MONTH, 7 * oneDayGap, 79719.068161);
+
+    verifyJointSegmentsCost(GranularityType.MONTH, GranularityType.YEAR, noGap, 100645.313535);
+    verifyJointSegmentsCost(GranularityType.YEAR, GranularityType.YEAR, noGap, 1208453.347454);
+    verifyJointSegmentsCost(GranularityType.YEAR, GranularityType.YEAR, 7 * oneDayGap, 1189943.571325);
+  }
+
+  @Test
+  public void testJointSegmentsCostSymmetry()
+  {
+    final DataSegment segmentA = CreateDataSegments.ofDatasource(DS_WIKI)
+                                                   .forIntervals(1, Granularities.DAY)
+                                                   .startingAt("2010-01-01")
+                                                   .eachOfSizeInMb(100).get(0);
+    final DataSegment segmentB = CreateDataSegments.ofDatasource(DS_WIKI)
+                                                   .forIntervals(1, Granularities.MONTH)
+                                                   .startingAt("2010-01-01")
+                                                   .eachOfSizeInMb(100).get(0);
+
+    Assert.assertEquals(
+        CostBalancerStrategy.computeJointSegmentsCost(segmentA, segmentB),
+        CostBalancerStrategy.computeJointSegmentsCost(segmentB, segmentA),
+        DELTA
     );
-    return segment;
   }
 
   @Test
-  public void testCostBalancerMultiThreadedStrategy()
+  public void testJointSegmentsCostMultipleDatasources()
   {
-    List<ServerHolder> serverHolderList = setupDummyCluster(10, 20);
-    DataSegment segment = getSegment(1000);
+    final DataSegment wikiSegment = CreateDataSegments.ofDatasource(DS_WIKI)
+                                                      .forIntervals(1, Granularities.DAY)
+                                                      .startingAt("2010-01-01")
+                                                      .eachOfSizeInMb(100).get(0);
+    final DataSegment koalaSegment = CreateDataSegments.ofDatasource("koala")
+                                                       .forIntervals(1, Granularities.DAY)
+                                                       .startingAt("2010-01-01")
+                                                       .eachOfSizeInMb(100).get(0);
 
-    BalancerStrategy strategy = new CostBalancerStrategy(
-        MoreExecutors.listeningDecorator(Execs.multiThreaded(4, "CostBalancerStrategyTest-%d"))
+    // Verify that same datasource cost is twice that of cross datasource cost
+    final double crossDatasourceCost =
+        CostBalancerStrategy.computeJointSegmentsCost(koalaSegment, wikiSegment);
+    Assert.assertEquals(
+        2 * crossDatasourceCost,
+        CostBalancerStrategy.computeJointSegmentsCost(wikiSegment, wikiSegment),
+        DELTA
+    );
+    Assert.assertEquals(
+        2 * crossDatasourceCost,
+        CostBalancerStrategy.computeJointSegmentsCost(koalaSegment, koalaSegment),
+        DELTA
     );
-    ServerHolder holder = strategy.findNewSegmentHomeReplicator(segment, serverHolderList);
-    Assert.assertNotNull("Should be able to find a place for new segment!!", holder);
-    Assert.assertEquals("Best Server should be BEST_SERVER", "BEST_SERVER", holder.getServer().getName());
   }
 
   @Test
-  public void testCostBalancerSingleThreadStrategy()
+  public void testJointSegmentsCostWith45DayGap()
   {
-    List<ServerHolder> serverHolderList = setupDummyCluster(10, 20);
-    DataSegment segment = getSegment(1000);
+    // start of 2nd segment - end of 1st segment = 45 days
+    final long gap1Day = TimeUnit.DAYS.toMillis(1);
+    final long gap45Days = 45 * gap1Day;
+
+    // This test establishes that after 45 days, all costs become negligible
+    // (except with ALL granularity)
+
+    // Add extra gap to ensure that segments have no overlap
+    final long gap1Hour = TimeUnit.HOURS.toMillis(1);
+    verifyJointSegmentsCost(GranularityType.HOUR, GranularityType.HOUR, gap1Hour + gap45Days, 0);
+    verifyJointSegmentsCost(GranularityType.HOUR, GranularityType.DAY, gap1Hour + gap45Days, 0);
+
+    verifyJointSegmentsCost(GranularityType.DAY, GranularityType.DAY, gap1Day + gap45Days, 0);
+    verifyJointSegmentsCost(GranularityType.DAY, GranularityType.MONTH, gap1Day + gap45Days, 0);
 
-    BalancerStrategy strategy = new CostBalancerStrategy(
-        MoreExecutors.listeningDecorator(Execs.multiThreaded(1, "CostBalancerStrategyTest-%d"))
+    verifyJointSegmentsCost(GranularityType.MONTH, GranularityType.MONTH, 30 * gap1Day + gap45Days, 0);
+    verifyJointSegmentsCost(GranularityType.YEAR, GranularityType.YEAR, 365 * gap1Day + gap45Days, 0);
+  }
+
+  @Test
+  public void testJointSegmentsCostAllGranularity()
+  {
+    // Cost of ALL with other granularities
+    verifyJointSegmentsCost(GranularityType.HOUR, GranularityType.ALL, 0, 138.516732);
+    verifyJointSegmentsCost(GranularityType.DAY, GranularityType.ALL, 0, 3323.962523);
+    verifyJointSegmentsCost(GranularityType.MONTH, GranularityType.ALL, 0, 103043.057744);
+    verifyJointSegmentsCost(GranularityType.YEAR, GranularityType.ALL, 0, 1213248.808913);
+
+    // Self cost of an ALL granularity segment
+    DataSegment segmentAllGranularity =
+        CreateDataSegments.ofDatasource("ds")
+                          .forIntervals(1, Granularities.ALL)
+                          .eachOfSizeInMb(100).get(0);
+    double cost = CostBalancerStrategy.computeJointSegmentsCost(
+        segmentAllGranularity,
+        segmentAllGranularity
     );
-    ServerHolder holder = strategy.findNewSegmentHomeReplicator(segment, serverHolderList);
-    Assert.assertNotNull("Should be able to find a place for new segment!!", holder);
-    Assert.assertEquals("Best Server should be BEST_SERVER", "BEST_SERVER", holder.getServer().getName());
+    Assert.assertTrue(cost >= 3.548e14 && cost <= 3.549e14);
   }
 
   @Test
-  public void testComputeJointSegmentCost()
+  public void testComputeCost()
   {
-    DateTime referenceTime = DateTimes.of("2014-01-01T00:00:00");
-    double segmentCost = CostBalancerStrategy.computeJointSegmentsCost(
-        getSegment(
-            100,
-            "DUMMY",
-            new Interval(
-                referenceTime,
-                referenceTime.plusHours(1)
-            )
-        ),
-        getSegment(
-            101,
-            "DUMMY",
-            new Interval(
-                referenceTime.minusHours(2),
-                referenceTime.minusHours(2).plusHours(1)
-            )
-        )
+    // Create segments for different granularities
+    final List<DataSegment> daySegments =
+        CreateDataSegments.ofDatasource(DS_WIKI)
+                          .forIntervals(10, Granularities.DAY)
+                          .startingAt("2022-01-01")
+                          .withNumPartitions(10)
+                          .eachOfSizeInMb(100);
+
+    final List<DataSegment> monthSegments =
+        CreateDataSegments.ofDatasource(DS_WIKI)
+                          .forIntervals(10, Granularities.MONTH)
+                          .startingAt("2022-03-01")
+                          .withNumPartitions(10)
+                          .eachOfSizeInMb(100);
+
+    final List<DataSegment> yearSegments =
+        CreateDataSegments.ofDatasource(DS_WIKI)
+                          .forIntervals(1, Granularities.YEAR)
+                          .startingAt("2023-01-01")
+                          .withNumPartitions(30)
+                          .eachOfSizeInMb(100);
+
+    // Distribute the segments randomly amongst 3 servers
+    final List<DataSegment> segments = new ArrayList<>(daySegments);
+    segments.addAll(monthSegments);
+    segments.addAll(yearSegments);
+
+    List<DruidServer> historicals = IntStream.range(0, 3)
+                                             .mapToObj(i -> createHistorical())
+                                             .collect(Collectors.toList());
+    final Random random = new Random(100);
+    segments.forEach(
+        segment -> historicals.get(random.nextInt(historicals.size()))
+                              .addDataSegment(segment)
     );
 
-    Assert.assertEquals(
-        CostBalancerStrategy.INV_LAMBDA_SQUARE * CostBalancerStrategy.intervalCost(
-            1 * CostBalancerStrategy.LAMBDA,
-            -2 * CostBalancerStrategy.LAMBDA,
-            -1 * CostBalancerStrategy.LAMBDA
-        ) * 2,
-        segmentCost, 1e-6);
+    // Create ServerHolder for each server
+    final List<ServerHolder> serverHolders = historicals.stream().map(
+        server -> new ServerHolder(server.toImmutableDruidServer(), new LoadQueuePeonTester())
+    ).collect(Collectors.toList());
+
+    // Verify costs for DAY, MONTH and YEAR segments
+    verifyServerCosts(
+        daySegments.get(0),
+        serverHolders,
+        5191.500804, 8691.392080, 6418.467818
+    );
+    verifyServerCosts(
+        monthSegments.get(0),
+        serverHolders,
+        301935.940609, 301935.940606, 304333.677669
+    );
+    verifyServerCosts(
+        yearSegments.get(0),
+        serverHolders,
+        8468764.380437, 12098919.896931, 14501440.169452
+    );
+
+    // Verify costs for an ALL granularity segment
+    DataSegment allGranularitySegment =
+        CreateDataSegments.ofDatasource(DS_WIKI)
+                          .forIntervals(1, Granularities.ALL)
+                          .eachOfSizeInMb(100).get(0);
+    verifyServerCosts(
+        allGranularitySegment,
+        serverHolders,
+        1.1534173737329768e7,
+        1.6340633534241956e7,
+        1.9026400521582970e7
+    );
   }
 
-  @Test
-  public void testIntervalCost()
+  private void verifyServerCosts(
+      DataSegment segment,
+      List<ServerHolder> serverHolders,
+      double... expectedCosts
+  )
+  {
+    for (int i = 0; i < serverHolders.size(); ++i) {
+      double observedCost = strategy.computeCost(segment, serverHolders.get(i), true);
+      Assert.assertEquals(expectedCosts[i], observedCost, DELTA);
+    }
+  }
+
+  private void verifyJointSegmentsCost(
+      GranularityType granularityX,
+      GranularityType granularityY,
+      long startGapMillis,
+      double expectedCost
+  )
+  {
+    final DataSegment segmentX =
+        CreateDataSegments.ofDatasource(DS_WIKI)
+                          .forIntervals(1, granularityX.getDefaultGranularity())
+                          .startingAt("2012-10-24")
+                          .eachOfSizeInMb(100).get(0);
+
+    long startTimeY = segmentX.getInterval().getStartMillis() + startGapMillis;
+    final DataSegment segmentY =
+        CreateDataSegments.ofDatasource(DS_WIKI)
+                          .forIntervals(1, granularityY.getDefaultGranularity())
+                          .startingAt(startTimeY)
+                          .eachOfSizeInMb(100).get(0);
+
+    double observedCost = CostBalancerStrategy.computeJointSegmentsCost(segmentX, segmentY);
+    Assert.assertEquals(expectedCost, observedCost, DELTA);
+  }
+
+  private DruidServer createHistorical()
   {
-    // additivity
-    Assert.assertEquals(CostBalancerStrategy.intervalCost(1, 1, 3),
-                        CostBalancerStrategy.intervalCost(1, 1, 2) +
-                        CostBalancerStrategy.intervalCost(1, 2, 3), 1e-6);
-
-    Assert.assertEquals(CostBalancerStrategy.intervalCost(2, 1, 3),
-                        CostBalancerStrategy.intervalCost(2, 1, 2) +
-                        CostBalancerStrategy.intervalCost(2, 2, 3), 1e-6);
-
-    Assert.assertEquals(CostBalancerStrategy.intervalCost(3, 1, 2),
-                        CostBalancerStrategy.intervalCost(1, 1, 2) +
-                        CostBalancerStrategy.intervalCost(1, 0, 1) +
-                        CostBalancerStrategy.intervalCost(1, 1, 2), 1e-6);
-
-    // no overlap [0, 1) [1, 2)
-    Assert.assertEquals(0.3995764, CostBalancerStrategy.intervalCost(1, 1, 2), 1e-6);
-    // no overlap [0, 1) [-1, 0)
-    Assert.assertEquals(0.3995764, CostBalancerStrategy.intervalCost(1, -1, 0), 1e-6);
-
-    // exact overlap [0, 1), [0, 1)
-    Assert.assertEquals(0.7357589, CostBalancerStrategy.intervalCost(1, 0, 1), 1e-6);
-    // exact overlap [0, 2), [0, 2)
-    Assert.assertEquals(2.270671, CostBalancerStrategy.intervalCost(2, 0, 2), 1e-6);
-    // partial overlap [0, 2), [1, 3)
-    Assert.assertEquals(1.681908, CostBalancerStrategy.intervalCost(2, 1, 3), 1e-6);
-    // partial overlap [0, 2), [1, 2)
-    Assert.assertEquals(1.135335, CostBalancerStrategy.intervalCost(2, 1, 2), 1e-6);
-    // partial overlap [0, 2), [0, 1)
-    Assert.assertEquals(1.135335, CostBalancerStrategy.intervalCost(2, 0, 1), 1e-6);
-    // partial overlap [0, 3), [1, 2)
-    Assert.assertEquals(1.534912, CostBalancerStrategy.intervalCost(3, 1, 2), 1e-6);
+    String serverName = "hist_" + uniqueServerId++;
+    return new DruidServer(serverName, serverName, null, 10L << 30, ServerType.HISTORICAL, "hot", 1);
   }
+
 }
diff --git a/server/src/test/java/org/apache/druid/server/coordinator/CreateDataSegments.java b/server/src/test/java/org/apache/druid/server/coordinator/CreateDataSegments.java
index 8f2c123891..9b1aa71131 100644
--- a/server/src/test/java/org/apache/druid/server/coordinator/CreateDataSegments.java
+++ b/server/src/test/java/org/apache/druid/server/coordinator/CreateDataSegments.java
@@ -20,6 +20,8 @@
 package org.apache.druid.server.coordinator;
 
 import org.apache.druid.java.util.common.DateTimes;
+import org.apache.druid.java.util.common.Intervals;
+import org.apache.druid.java.util.common.granularity.Granularities;
 import org.apache.druid.java.util.common.granularity.Granularity;
 import org.apache.druid.segment.IndexIO;
 import org.apache.druid.timeline.DataSegment;
@@ -30,6 +32,7 @@ import org.joda.time.Interval;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
+import java.util.Objects;
 
 /**
  * Test utility to create {@link DataSegment}s for a given datasource.
@@ -40,8 +43,8 @@ public class CreateDataSegments
 
   private DateTime startTime;
   private Granularity granularity;
-  private int numPartitions;
-  private int numIntervals;
+  private int numPartitions = 1;
+  private int numIntervals = 1;
 
   public static CreateDataSegments ofDatasource(String datasource)
   {
@@ -66,6 +69,12 @@ public class CreateDataSegments
     return this;
   }
 
+  public CreateDataSegments startingAt(long startOfFirstInterval)
+  {
+    this.startTime = DateTimes.utc(startOfFirstInterval);
+    return this;
+  }
+
   public CreateDataSegments withNumPartitions(int numPartitions)
   {
     this.numPartitions = numPartitions;
@@ -74,12 +83,19 @@ public class CreateDataSegments
 
   public List<DataSegment> eachOfSizeInMb(long sizeMb)
   {
-    final List<DataSegment> segments = new ArrayList<>();
+    boolean isEternityInterval = Objects.equals(granularity, Granularities.ALL);
+    if (isEternityInterval) {
+      numIntervals = 1;
+    }
 
     int uniqueIdInInterval = 0;
     DateTime nextStart = startTime;
+
+    final List<DataSegment> segments = new ArrayList<>();
     for (int numInterval = 0; numInterval < numIntervals; ++numInterval) {
-      Interval nextInterval = new Interval(nextStart, granularity.increment(nextStart));
+      Interval nextInterval = isEternityInterval
+                              ? Intervals.ETERNITY
+                              : new Interval(nextStart, granularity.increment(nextStart));
       for (int numPartition = 0; numPartition < numPartitions; ++numPartition) {
         segments.add(
             new NumberedDataSegment(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@druid.apache.org
For additional commands, e-mail: commits-help@druid.apache.org