You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@cassandra.apache.org by ma...@apache.org on 2020/09/21 08:02:44 UTC

[cassandra-diff] branch master updated: Allow optional query retry

This is an automated email from the ASF dual-hosted git repository.

marcuse pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/cassandra-diff.git


The following commit(s) were added to refs/heads/master by this push:
     new 4c9bc4f  Allow optional query retry
4c9bc4f is described below

commit 4c9bc4f4e3fd7d23b1284c89266ffbf10b8f0183
Author: Yifan Cai <yi...@apple.com>
AuthorDate: Tue Sep 15 22:39:56 2020 -0700

    Allow optional query retry
    
    Patch by Yifan Cai; reviewed by marcuse for CASSANDRA-16125
---
 .../diff/ExponentialRetryStrategyProvider.java     | 102 +++++++++++++++++++++
 .../apache/cassandra/diff/JobConfiguration.java    |  12 +++
 .../org/apache/cassandra/diff/RetryStrategy.java   |  45 +++++++++
 .../cassandra/diff/RetryStrategyProvider.java      |  46 ++++++++++
 .../cassandra/diff/YamlJobConfiguration.java       |   5 +
 .../diff/ExponentialRetryStrategyTest.java         | 100 ++++++++++++++++++++
 .../apache/cassandra/diff/NoRetryStrategyTest.java |  26 ++++++
 .../cassandra/diff/YamlJobConfigurationTest.java   |  25 +++++
 common/src/test/resources/testconfig.yaml          |   4 +
 .../org/apache/cassandra/diff/DiffCluster.java     |  22 +++--
 .../java/org/apache/cassandra/diff/Differ.java     |  13 ++-
 .../apache/cassandra/diff/PartitionComparator.java |  11 ++-
 .../diff/AbstractMockJobConfiguration.java         |   5 +
 .../cassandra/diff/PartitionComparatorTest.java    |   2 +-
 .../apache/cassandra/diff/RangeComparatorTest.java |  11 ++-
 15 files changed, 409 insertions(+), 20 deletions(-)

diff --git a/common/src/main/java/org/apache/cassandra/diff/ExponentialRetryStrategyProvider.java b/common/src/main/java/org/apache/cassandra/diff/ExponentialRetryStrategyProvider.java
new file mode 100644
index 0000000..639d14f
--- /dev/null
+++ b/common/src/main/java/org/apache/cassandra/diff/ExponentialRetryStrategyProvider.java
@@ -0,0 +1,102 @@
+package org.apache.cassandra.diff;
+
+import java.util.concurrent.TimeUnit;
+
+import com.google.common.base.Preconditions;
+import com.google.common.util.concurrent.Uninterruptibles;
+
+import static org.apache.cassandra.diff.ExponentialRetryStrategyProvider.ExponentialRetryStrategy.BASE_DELAY_MS_KEY;
+import static org.apache.cassandra.diff.ExponentialRetryStrategyProvider.ExponentialRetryStrategy.DEFAULT_BASE_DELAY_MS;
+import static org.apache.cassandra.diff.ExponentialRetryStrategyProvider.ExponentialRetryStrategy.DEFAULT_TOTAL_DELAY_MS;
+import static org.apache.cassandra.diff.ExponentialRetryStrategyProvider.ExponentialRetryStrategy.TOTAL_DELAY_MS_KEY;
+
+public class ExponentialRetryStrategyProvider extends RetryStrategyProvider {
+    public ExponentialRetryStrategyProvider(JobConfiguration.RetryOptions retryOptions) {
+        super(retryOptions);
+    }
+
+    @Override
+    public RetryStrategy get() {
+        long baseDelayMs = Long.parseLong(retryOptions.getOrDefault(BASE_DELAY_MS_KEY, DEFAULT_BASE_DELAY_MS));
+        long totalDelayMs = Long.parseLong(retryOptions.getOrDefault(TOTAL_DELAY_MS_KEY, DEFAULT_TOTAL_DELAY_MS));
+        return new ExponentialRetryStrategy(baseDelayMs, totalDelayMs);
+    }
+
+    static class ExponentialRetryStrategy extends RetryStrategy {
+        public final static String BASE_DELAY_MS_KEY = "base_delay_ms";
+        public final static String TOTAL_DELAY_MS_KEY = "total_delay_ms";
+        final static String DEFAULT_BASE_DELAY_MS = String.valueOf(TimeUnit.SECONDS.toMillis(1));
+        final static String DEFAULT_TOTAL_DELAY_MS = String.valueOf(TimeUnit.MINUTES.toMillis(30));
+
+        private final Exponential exponential;
+        private int attempts = 0;
+
+        public ExponentialRetryStrategy(long baseDelayMs, long totalDelayMs) {
+            this.exponential = new Exponential(baseDelayMs, totalDelayMs);
+        }
+
+        @Override
+        protected boolean shouldRetry() {
+            long pauseTimeMs = exponential.get(attempts);
+            if (pauseTimeMs > 0) {
+                Uninterruptibles.sleepUninterruptibly(pauseTimeMs, TimeUnit.MILLISECONDS);
+                attempts += 1;
+                return true;
+            }
+            return false;
+        }
+
+        @Override
+        public String toString() {
+            return String.format("%s(baseDelayMs: %s, totalDelayMs: %s, currentAttempts: %s)",
+                                 this.getClass().getSimpleName(), exponential.baseDelayMs, exponential.totalDelayMs, attempts);
+        }
+    }
+
+    /**
+     * Calculate the pause time exponentially, according to the attempts.
+     * The total delay is capped at totalDelayMs, meaning the sum of all the previous pauses cannot exceed it.
+     */
+    static class Exponential {
+        // base delay in ms used to calculate the next pause time
+        private final long baseDelayMs;
+        // total delay in ms permitted
+        private final long totalDelayMs;
+
+        Exponential(long baseDelayMs, long totalDelayMs) {
+            Preconditions.checkArgument(baseDelayMs <= totalDelayMs, "baseDelayMs cannot be greater than totalDelayMs");
+            this.baseDelayMs = baseDelayMs;
+            this.totalDelayMs = totalDelayMs;
+        }
+
+        /**
+         * Calculate the pause time based on attempts.
+         * It is guaranteed that the sum of all the pauses does not exceed totalDelayMs.
+         * @param attempts number of attempts, starting from 0.
+         * @return the next pause time in milliseconds, or a negative value if no further pause is allowed.
+         */
+        long get(int attempts) {
+            long nextMaybe = baseDelayMs << attempts; // Do not care about overflow. pausedInTotal() corrects the value
+            if (attempts == 0) { // first retry
+                return nextMaybe;
+            } else {
+                long pausedInTotal = pausedInTotal(attempts);
+                if (pausedInTotal < totalDelayMs) {
+                    return Math.min(totalDelayMs - pausedInTotal, nextMaybe); // adjust the next pause time if possible
+                }
+                return -1; // the previous retries have exhausted the permits
+            }
+        }
+
+        // Returns the total pause time according to the `attempts`,
+        // i.e. [0, attempts), which is guaranteed to be greater than or equal to 0.
+        // No overflow can happen.
+        private long pausedInTotal(int attempts) {
+            // take care of overflow. Such long pause time is not realistic though.
+            if (attempts >= Long.numberOfLeadingZeros(baseDelayMs))
+                return totalDelayMs;
+            long result = (baseDelayMs << attempts) - baseDelayMs; // base * (2^0 + 2^1 + ... + 2^(n-1)) = base * 2^n - base
+            return Math.min(totalDelayMs, result);
+        }
+    }
+}
diff --git a/common/src/main/java/org/apache/cassandra/diff/JobConfiguration.java b/common/src/main/java/org/apache/cassandra/diff/JobConfiguration.java
index cf12ea3..7a20b30 100644
--- a/common/src/main/java/org/apache/cassandra/diff/JobConfiguration.java
+++ b/common/src/main/java/org/apache/cassandra/diff/JobConfiguration.java
@@ -20,6 +20,7 @@
 package org.apache.cassandra.diff;
 
 import java.io.Serializable;
+import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
@@ -86,6 +87,17 @@ public interface JobConfiguration extends Serializable {
 
     MetadataKeyspaceOptions metadataOptions();
 
+    /**
+     * Contains the options that specify the retry strategy for retrieving data at the application level.
+     * Note that it is different from the Cassandra Java driver's {@link com.datastax.driver.core.policies.RetryPolicy},
+     * which is evaluated on the Netty worker threads.
+     */
+    RetryOptions retryOptions();
+
     Map<String, String> clusterConfig(String identifier);
 
+    // Just an alias
+    public static class RetryOptions extends HashMap<String, String> {
+    }
+
 }
diff --git a/common/src/main/java/org/apache/cassandra/diff/RetryStrategy.java b/common/src/main/java/org/apache/cassandra/diff/RetryStrategy.java
new file mode 100644
index 0000000..b0cd7c6
--- /dev/null
+++ b/common/src/main/java/org/apache/cassandra/diff/RetryStrategy.java
@@ -0,0 +1,45 @@
+package org.apache.cassandra.diff;
+
+import java.util.concurrent.Callable;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public abstract class RetryStrategy {
+    private final static Logger logger = LoggerFactory.getLogger(RetryStrategy.class);
+
+    /**
+     * Decide whether retry is desired or not.
+     * @return true to retry (see {@link #retry(Callable)}),
+     *         or false to have the exception re-thrown.
+     */
+    protected abstract boolean shouldRetry();
+
+    public final <T> T retry(Callable<T> retryable) throws Exception {
+        while (true) {
+            try {
+                return retryable.call();
+            }
+            catch (Exception exception) {
+                if (!shouldRetry()) {
+                    throw exception;
+                }
+                logger.warn("Retry with " + toString());
+            }
+        }
+    }
+
+    public static class NoRetry extends RetryStrategy {
+        public final static RetryStrategy INSTANCE = new NoRetry();
+
+        @Override
+        public boolean shouldRetry() {
+            return false;
+        }
+
+        @Override
+        public String toString() {
+            return this.getClass().getSimpleName();
+        }
+    }
+}
diff --git a/common/src/main/java/org/apache/cassandra/diff/RetryStrategyProvider.java b/common/src/main/java/org/apache/cassandra/diff/RetryStrategyProvider.java
new file mode 100644
index 0000000..ec29bb7
--- /dev/null
+++ b/common/src/main/java/org/apache/cassandra/diff/RetryStrategyProvider.java
@@ -0,0 +1,46 @@
+package org.apache.cassandra.diff;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Provides new RetryStrategy instances.
+ * Uses an abstract class instead of an interface in order to retain the reference to retryOptions.
+ */
+public abstract class RetryStrategyProvider {
+    protected final JobConfiguration.RetryOptions retryOptions;
+
+    public RetryStrategyProvider(JobConfiguration.RetryOptions retryOptions) {
+        this.retryOptions = retryOptions;
+    }
+
+    /**
+     * Create a new instance of RetryStrategy.
+     */
+    public abstract RetryStrategy get();
+
+
+    public final static String IMPLEMENTATION_KEY = "impl";
+    private final static Logger logger = LoggerFactory.getLogger(RetryStrategyProvider.class);
+
+    /**
+     * Create a RetryStrategyProvider based on {@code retryOptions}; falls back to a NoRetry provider on failure.
+     */
+    public static RetryStrategyProvider create(JobConfiguration.RetryOptions retryOptions) {
+        try {
+            String implClass = retryOptions.get(IMPLEMENTATION_KEY);
+            return (RetryStrategyProvider) Class.forName(implClass)
+                                                .getConstructor(JobConfiguration.RetryOptions.class)
+                                                .newInstance(retryOptions);
+        } catch (Exception ex) {
+            logger.warn("Unable to create RetryStrategyProvider. Use the default provider, NoRetry.", ex);
+
+            return new RetryStrategyProvider(retryOptions) {
+                @Override
+                public RetryStrategy get() {
+                    return RetryStrategy.NoRetry.INSTANCE;
+                }
+            };
+        }
+    }
+}
diff --git a/common/src/main/java/org/apache/cassandra/diff/YamlJobConfiguration.java b/common/src/main/java/org/apache/cassandra/diff/YamlJobConfiguration.java
index c49da20..359466a 100644
--- a/common/src/main/java/org/apache/cassandra/diff/YamlJobConfiguration.java
+++ b/common/src/main/java/org/apache/cassandra/diff/YamlJobConfiguration.java
@@ -47,6 +47,7 @@ public class YamlJobConfiguration implements JobConfiguration {
     public Map<String, Map<String, String>> cluster_config;
     public String specific_tokens = null;
     public String disallowed_tokens = null;
+    public RetryOptions retry_options;
 
     public static YamlJobConfiguration load(InputStream inputStream) {
         Yaml yaml = new Yaml(new CustomClassLoaderConstructor(YamlJobConfiguration.class,
@@ -102,6 +103,10 @@ public class YamlJobConfiguration implements JobConfiguration {
         return metadata_options;
     }
 
+    public RetryOptions retryOptions() {
+        return retry_options;
+    }
+
     public Map<String, String> clusterConfig(String identifier) {
         return cluster_config.get(identifier);
     }
diff --git a/common/src/test/java/org/apache/cassandra/diff/ExponentialRetryStrategyTest.java b/common/src/test/java/org/apache/cassandra/diff/ExponentialRetryStrategyTest.java
new file mode 100644
index 0000000..26940d7
--- /dev/null
+++ b/common/src/test/java/org/apache/cassandra/diff/ExponentialRetryStrategyTest.java
@@ -0,0 +1,100 @@
+package org.apache.cassandra.diff;
+
+import java.util.Random;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import org.junit.Assert;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+import static org.apache.cassandra.diff.ExponentialRetryStrategyProvider.Exponential;
+import static org.apache.cassandra.diff.ExponentialRetryStrategyProvider.ExponentialRetryStrategy;
+
+public class ExponentialRetryStrategyTest {
+    @Rule
+    public ExpectedException expectedException = ExpectedException.none();
+
+    @Test
+    public void testPauseTimeIncreaseExponentially() {
+        long base = 10;
+        long total = 1000;
+        Exponential exponential = new Exponential(base, total);
+        long totalSoFar = 0;
+        for (int i = 0; i < 100; i ++) {
+            long actual = exponential.get(i);
+            long expected = base << i;
+            if (totalSoFar >= total) {
+                expected = -1;
+            } else {
+                if (totalSoFar + expected > total) {
+                    expected = total - totalSoFar; // adjust the pause time for the last valid pause.
+                }
+                totalSoFar += expected;
+            }
+            Assert.assertEquals("Exponential generates unexpected sequence at iteration#" + i, expected, actual);
+        }
+        Assert.assertEquals("The total pause time is not capped at totalDelayMs", total, totalSoFar);
+    }
+
+    @Test
+    public void testWrongArguments() {
+        expectedException.expect(IllegalArgumentException.class);
+        expectedException.expectMessage("baseDelayMs cannot be greater than totalDelayMs");
+        new Exponential(10, 1);
+    }
+
+    @Test
+    public void testToString() {
+        ExponentialRetryStrategyProvider provider = new ExponentialRetryStrategyProvider(new JobConfiguration.RetryOptions());
+        String output = provider.get().toString();
+        Assert.assertEquals("ExponentialRetryStrategy(baseDelayMs: 1000, totalDelayMs: 1800000, currentAttempts: 0)",
+                            output);
+    }
+
+    @Test
+    public void testSuccessAfterRetry() throws Exception {
+        AtomicInteger retryCount = new AtomicInteger(0);
+        ExponentialRetryStrategy strategy = new ExponentialRetryStrategy(1, 1000);
+        int result = strategy.retry(() -> {
+            if (retryCount.getAndIncrement() < 2) {
+                throw new RuntimeException("fail");
+            }
+            return 1;
+        });
+        Assert.assertEquals(1, result);
+        Assert.assertEquals(3, retryCount.get());
+    }
+
+    @Test
+    public void testFailureAfterAllRetries() throws Exception {
+        AtomicInteger execCount = new AtomicInteger(0);
+        ExponentialRetryStrategy strategy = new ExponentialRetryStrategy(1, 2);
+        expectedException.expect(RuntimeException.class);
+        expectedException.expectMessage("fail at execution#2"); // 0 based
+        // the lambda runs 3 times, at timestamps 0, 1 and 2, and fails every time
+        strategy.retry(() -> {
+            throw new RuntimeException("fail at execution#" + execCount.getAndIncrement());
+        });
+    }
+
+    @Test
+    public void testOverflowPrevention() {
+        Random rand = new Random();
+        for (int i = 0; i < 1000; i++) {
+            long base = rand.nextInt(100000) + 1; // [1, 100000]
+            int leadingZeros = Long.numberOfLeadingZeros(base);
+            Exponential exponential = new Exponential(base, Long.MAX_VALUE);
+            Assert.assertTrue("The last attempt that still generate valid pause time. Failed with base: " + base,
+                              exponential.get(leadingZeros - 1) > 0);
+            Assert.assertEquals("Failed with base: " + base, -1, exponential.get(leadingZeros));
+        }
+    }
+
+    private JobConfiguration.RetryOptions retryOptions(long baseDelayMs, long totalDelayMs) {
+        return new JobConfiguration.RetryOptions() {{
+            put(ExponentialRetryStrategy.BASE_DELAY_MS_KEY, String.valueOf(baseDelayMs));
+            put(ExponentialRetryStrategy.TOTAL_DELAY_MS_KEY, String.valueOf(totalDelayMs));
+        }};
+    }
+}
diff --git a/common/src/test/java/org/apache/cassandra/diff/NoRetryStrategyTest.java b/common/src/test/java/org/apache/cassandra/diff/NoRetryStrategyTest.java
new file mode 100644
index 0000000..6428444
--- /dev/null
+++ b/common/src/test/java/org/apache/cassandra/diff/NoRetryStrategyTest.java
@@ -0,0 +1,26 @@
+package org.apache.cassandra.diff;
+
+import java.util.concurrent.atomic.AtomicInteger;
+
+import org.junit.Assert;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+public class NoRetryStrategyTest {
+    @Rule
+    public ExpectedException expectedException = ExpectedException.none();
+
+    @Test
+    public void testNoRetry() throws Exception {
+        RetryStrategy strategy = RetryStrategy.NoRetry.INSTANCE;
+        Assert.assertFalse("NoRetry should always not retry",
+                           strategy.shouldRetry());
+        AtomicInteger execCount = new AtomicInteger(0);
+        expectedException.expect(RuntimeException.class);
+        expectedException.expectMessage("fail at execution#0"); // no retry
+        strategy.retry(() -> {
+            throw new RuntimeException("fail at execution#" + execCount.getAndIncrement());
+        });
+    }
+}
diff --git a/common/src/test/java/org/apache/cassandra/diff/YamlJobConfigurationTest.java b/common/src/test/java/org/apache/cassandra/diff/YamlJobConfigurationTest.java
index 39c43ee..7aa72ab 100644
--- a/common/src/test/java/org/apache/cassandra/diff/YamlJobConfigurationTest.java
+++ b/common/src/test/java/org/apache/cassandra/diff/YamlJobConfigurationTest.java
@@ -3,6 +3,10 @@ package org.apache.cassandra.diff;
 import org.junit.Assert;
 import org.junit.Test;
 
+import org.hamcrest.CoreMatchers;
+
+import static org.apache.cassandra.diff.ExponentialRetryStrategyProvider.ExponentialRetryStrategy;
+
 public class YamlJobConfigurationTest {
     @Test
     public void testLoadYaml() {
@@ -12,6 +16,11 @@ public class YamlJobConfigurationTest {
             Assert.assertTrue("Keyspace segment is not loaded correctly", kt.keyspace.contains("ks"));
             Assert.assertTrue("Table segment is not loaded correctly", kt.table.contains("tb"));
         });
+        JobConfiguration.RetryOptions retryOptions = jobConfiguration.retryOptions();
+        Assert.assertNotNull("retry_options not defined", retryOptions);
+        Assert.assertNotNull("impl not defined", retryOptions.get(ExponentialRetryStrategyProvider.IMPLEMENTATION_KEY));
+        Assert.assertNotNull("base_delay_ms not defined", retryOptions.get(ExponentialRetryStrategy.BASE_DELAY_MS_KEY));
+        Assert.assertNotNull("total_delay_ms not defined", retryOptions.get(ExponentialRetryStrategy.TOTAL_DELAY_MS_KEY));
     }
 
     @Test
@@ -31,6 +40,22 @@ public class YamlJobConfigurationTest {
         Assert.assertFalse("It should not be in the discover mode", jobConfiguration.shouldAutoDiscoverTables());
     }
 
+    @Test
+    public void testInstatiateRetryStrategyProvider() {
+        JobConfiguration withExponentialRetry = load("testconfig.yaml");
+        RetryStrategyProvider provider = RetryStrategyProvider.create(withExponentialRetry.retryOptions());
+        Assert.assertThat(provider, CoreMatchers.instanceOf(ExponentialRetryStrategyProvider.class));
+        Assert.assertThat(provider.get(), CoreMatchers.instanceOf(ExponentialRetryStrategy.class));
+
+        // empty retry option leads to NoRetry strategy
+        provider = RetryStrategyProvider.create(new JobConfiguration.RetryOptions());
+        Assert.assertThat(provider.get(), CoreMatchers.sameInstance(RetryStrategy.NoRetry.INSTANCE));
+
+        // null retry option leads to NoRetry strategy
+        provider = RetryStrategyProvider.create(null);
+        Assert.assertThat(provider.get(), CoreMatchers.sameInstance(RetryStrategy.NoRetry.INSTANCE));
+    }
+
     private JobConfiguration load(String filename) {
         return YamlJobConfiguration.load(getClass().getClassLoader().getResourceAsStream(filename));
     }
diff --git a/common/src/test/resources/testconfig.yaml b/common/src/test/resources/testconfig.yaml
index a860e48..be953fb 100644
--- a/common/src/test/resources/testconfig.yaml
+++ b/common/src/test/resources/testconfig.yaml
@@ -49,3 +49,7 @@ cluster_config:
     contact_points: "127.0.0.1"
     port: "9042"
     dc: "datacenter1"
+retry_options:
+  impl: "org.apache.cassandra.diff.ExponentialRetryStrategyProvider"
+  base_delay_ms: "1"
+  total_delay_ms: "2"
diff --git a/spark-job/src/main/java/org/apache/cassandra/diff/DiffCluster.java b/spark-job/src/main/java/org/apache/cassandra/diff/DiffCluster.java
index a60b8f7..da07890 100644
--- a/spark-job/src/main/java/org/apache/cassandra/diff/DiffCluster.java
+++ b/spark-job/src/main/java/org/apache/cassandra/diff/DiffCluster.java
@@ -25,7 +25,6 @@ import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
-import java.util.concurrent.ExecutionException;
 import java.util.concurrent.atomic.AtomicBoolean;
 
 import com.google.common.collect.AbstractIterator;
@@ -81,6 +80,7 @@ public class DiffCluster implements AutoCloseable
     private final int tokenScanFetchSize;
     private final int partitionReadFetchSize;
     private final int readTimeoutMillis;
+    private final RetryStrategyProvider retryStrategyProvider;
 
     private final AtomicBoolean stopped = new AtomicBoolean(false);
 
@@ -90,7 +90,8 @@ public class DiffCluster implements AutoCloseable
                        RateLimiter getPartitionRateLimiter,
                        int tokenScanFetchSize,
                        int partitionReadFetchSize,
-                       int readTimeoutMillis)
+                       int readTimeoutMillis,
+                       RetryStrategyProvider retryStrategyProvider)
 
     {
         this.consistencyLevel = consistencyLevel;
@@ -103,13 +104,17 @@ public class DiffCluster implements AutoCloseable
         this.tokenScanFetchSize = tokenScanFetchSize;
         this.partitionReadFetchSize = partitionReadFetchSize;
         this.readTimeoutMillis = readTimeoutMillis;
+        this.retryStrategyProvider = retryStrategyProvider;
     }
 
     public Iterator<PartitionKey> getPartitionKeys(KeyspaceTablePair table, final BigInteger prevToken, final BigInteger token) {
         try {
-            return Uninterruptibles.getUninterruptibly(fetchPartitionKeys(table, prevToken, token));
+            RetryStrategy retryStrategy = retryStrategyProvider.get();
+            return retryStrategy.retry(
+                () -> Uninterruptibles.getUninterruptibly(fetchPartitionKeys(table, prevToken, token))
+            );
         }
-        catch (ExecutionException ex) {
+        catch (Exception ex) {
             throw new RuntimeException(String.format("Unable to get partition keys (%s, %s] in table (%s) from cluster (%s)",
                                                      prevToken, token, table, clusterId.name()),
                                        ex);
@@ -144,9 +149,12 @@ public class DiffCluster implements AutoCloseable
 
     public Iterator<Row> getPartition(TableSpec table, PartitionKey key, boolean shouldReverse) {
         try {
-            return readPartition(table.getTable(), key, shouldReverse)
-                       .getUninterruptibly()
-                       .iterator();
+            RetryStrategy retryStrategy = retryStrategyProvider.get();
+            return retryStrategy.retry(
+                () -> readPartition(table.getTable(), key, shouldReverse)
+                          .getUninterruptibly()
+                          .iterator()
+            );
         }
         catch (Exception ex) {
             throw new RuntimeException(String.format("Unable to get partition (%s) in table (%s) from cluster (%s)",
diff --git a/spark-job/src/main/java/org/apache/cassandra/diff/Differ.java b/spark-job/src/main/java/org/apache/cassandra/diff/Differ.java
index 1ec78f0..b2cb527 100644
--- a/spark-job/src/main/java/org/apache/cassandra/diff/Differ.java
+++ b/spark-job/src/main/java/org/apache/cassandra/diff/Differ.java
@@ -62,6 +62,7 @@ public class Differ implements Serializable
     private final DiffJob.TrackerProvider trackerProvider;
     private final double reverseReadProbability;
     private final SpecificTokens specificTokens;
+    private final RetryStrategyProvider retryStrategyProvider;
 
     private static DiffCluster srcDiffCluster;
     private static DiffCluster targetDiffCluster;
@@ -100,6 +101,7 @@ public class Differ implements Serializable
         rateLimiter = RateLimiter.create(perExecutorRateLimit);
         this.reverseReadProbability = config.reverseReadProbability();
         this.specificTokens = config.specificTokens();
+        this.retryStrategyProvider = RetryStrategyProvider.create(config.retryOptions());
         synchronized (Differ.class)
         {
             /*
@@ -116,7 +118,8 @@ public class Differ implements Serializable
                                                  rateLimiter,
                                                  config.tokenScanFetchSize(),
                                                  config.partitionReadFetchSize(),
-                                                 config.readTimeoutMillis());
+                                                 config.readTimeoutMillis(),
+                                                 retryStrategyProvider);
             }
 
             if (targetDiffCluster == null)
@@ -127,7 +130,8 @@ public class Differ implements Serializable
                                                     rateLimiter,
                                                     config.tokenScanFetchSize(),
                                                     config.partitionReadFetchSize(),
-                                                    config.readTimeoutMillis());
+                                                    config.readTimeoutMillis(),
+                                                    retryStrategyProvider);
             }
 
             if (journalSession == null)
@@ -212,7 +216,7 @@ public class Differ implements Serializable
                 boolean reverse = context.shouldReverse();
                 Iterator<Row> source = fetchRows(context, key, reverse, DiffCluster.Type.SOURCE);
                 Iterator<Row> target = fetchRows(context, key, reverse, DiffCluster.Type.TARGET);
-                return new PartitionComparator(context.table, source, target);
+                return new PartitionComparator(context.table, source, target, retryStrategyProvider);
             };
 
         RangeComparator rangeComparator = new RangeComparator(context,
@@ -230,7 +234,8 @@ public class Differ implements Serializable
         Callable<Iterator<Row>> rows = () -> type == DiffCluster.Type.SOURCE
                                              ? context.source.getPartition(context.table, key, shouldReverse)
                                              : context.target.getPartition(context.table, key, shouldReverse);
-        return ClusterSourcedException.catches(type, rows);
+        RetryStrategy retryStrategy = retryStrategyProvider.get();
+        return ClusterSourcedException.catches(type, () -> retryStrategy.retry(rows));
     }
 
     @VisibleForTesting
diff --git a/spark-job/src/main/java/org/apache/cassandra/diff/PartitionComparator.java b/spark-job/src/main/java/org/apache/cassandra/diff/PartitionComparator.java
index 8aefb49..f0b23e1 100644
--- a/spark-job/src/main/java/org/apache/cassandra/diff/PartitionComparator.java
+++ b/spark-job/src/main/java/org/apache/cassandra/diff/PartitionComparator.java
@@ -36,13 +36,16 @@ public class PartitionComparator implements Callable<PartitionStats> {
     private final TableSpec tableSpec;
     private final Iterator<Row> source;
     private final Iterator<Row> target;
+    private final RetryStrategyProvider retryStrategyProvider;
 
     public PartitionComparator(TableSpec tableSpec,
                                Iterator<Row> source,
-                               Iterator<Row> target) {
+                               Iterator<Row> target,
+                               RetryStrategyProvider retryStrategyProvider) {
         this.tableSpec = tableSpec;
         this.source = source;
         this.target = target;
+        this.retryStrategyProvider = retryStrategyProvider;
     }
 
     public PartitionStats call() {
@@ -84,14 +87,16 @@ public class PartitionComparator implements Callable<PartitionStats> {
         Callable<Boolean> hasNext = () -> type == Type.SOURCE
                                           ? source.hasNext()
                                           : target.hasNext();
-        return ClusterSourcedException.catches(type, hasNext);
+        RetryStrategy retryStrategy = retryStrategyProvider.get();
+        return ClusterSourcedException.catches(type, () -> retryStrategy.retry(hasNext));
     }
 
     private Row getNextRow(Type type) {
         Callable<Row> next = () -> type == Type.SOURCE
                                    ? source.next()
                                    : target.next();
-        return ClusterSourcedException.catches(type, next);
+        RetryStrategy retryStrategy = retryStrategyProvider.get();
+        return ClusterSourcedException.catches(type, () -> retryStrategy.retry(next));
     }
 
     private boolean clusteringsEqual(Row source, Row target) {
diff --git a/spark-job/src/test/java/org/apache/cassandra/diff/AbstractMockJobConfiguration.java b/spark-job/src/test/java/org/apache/cassandra/diff/AbstractMockJobConfiguration.java
index 5f9d22a..c05f6e2 100644
--- a/spark-job/src/test/java/org/apache/cassandra/diff/AbstractMockJobConfiguration.java
+++ b/spark-job/src/test/java/org/apache/cassandra/diff/AbstractMockJobConfiguration.java
@@ -76,4 +76,9 @@ public abstract class AbstractMockJobConfiguration implements JobConfiguration {
     public Map<String, String> clusterConfig(String identifier) {
         throw uoe;
     }
+
+    @Override
+    public RetryOptions retryOptions() {
+        throw uoe;
+    }
 }
diff --git a/spark-job/src/test/java/org/apache/cassandra/diff/PartitionComparatorTest.java b/spark-job/src/test/java/org/apache/cassandra/diff/PartitionComparatorTest.java
index 9cd1892..dc886e4 100644
--- a/spark-job/src/test/java/org/apache/cassandra/diff/PartitionComparatorTest.java
+++ b/spark-job/src/test/java/org/apache/cassandra/diff/PartitionComparatorTest.java
@@ -210,7 +210,7 @@ public class PartitionComparatorTest {
     }
 
     PartitionComparator comparator(TableSpec table, Iterator<Row> source, Iterator<Row> target) {
-        return new PartitionComparator(table, source, target);
+        return new PartitionComparator(table, source, target, RetryStrategyProvider.create(null));
     }
 
     List<String> names(String...names) {
diff --git a/spark-job/src/test/java/org/apache/cassandra/diff/RangeComparatorTest.java b/spark-job/src/test/java/org/apache/cassandra/diff/RangeComparatorTest.java
index 0484212..fd2926b 100644
--- a/spark-job/src/test/java/org/apache/cassandra/diff/RangeComparatorTest.java
+++ b/spark-job/src/test/java/org/apache/cassandra/diff/RangeComparatorTest.java
@@ -54,6 +54,7 @@ public class RangeComparatorTest {
     private BiConsumer<RangeStats, BigInteger> progressReporter = (r, t) -> journal.put(t, copyOf(r));
     private Set<BigInteger> comparedPartitions = new HashSet<>();
     private ComparisonExecutor executor = ComparisonExecutor.newExecutor(1, new MetricRegistry());
+    private RetryStrategyProvider mockRetryStrategyFactory = RetryStrategyProvider.create(null); // create a NoRetry provider
 
     @Test
     public void emptyRange() {
@@ -472,7 +473,7 @@ public class RangeComparatorTest {
 
     // yield a PartitionComparator which always concludes that partitions being compared are identical
     PartitionComparator alwaysMatch(PartitionKey key) {
-        return new PartitionComparator(null, null, null) {
+        return new PartitionComparator(null, null, null, mockRetryStrategyFactory) {
             public PartitionStats call() {
                 comparedPartitions.add(key.getTokenAsBigInteger());
                 return new PartitionStats();
@@ -482,7 +483,7 @@ public class RangeComparatorTest {
 
     // yield a PartitionComparator which always determines that the partitions have a row-level mismatch
     PartitionComparator rowMismatch(PartitionKey key) {
-        return new PartitionComparator(null, null,  null) {
+        return new PartitionComparator(null, null,  null, mockRetryStrategyFactory) {
             public PartitionStats call() {
                 comparedPartitions.add(key.getTokenAsBigInteger());
                 PartitionStats stats = new PartitionStats();
@@ -494,7 +495,7 @@ public class RangeComparatorTest {
 
     // yield a PartitionComparator which always determines that the partitions have a 10 mismatching values
     PartitionComparator valuesMismatch(PartitionKey key) {
-        return new PartitionComparator(null, null,  null) {
+        return new PartitionComparator(null, null,  null, mockRetryStrategyFactory) {
             public PartitionStats call() {
                 comparedPartitions.add(key.getTokenAsBigInteger());
                 PartitionStats stats = new PartitionStats();
@@ -522,7 +523,7 @@ public class RangeComparatorTest {
     Function<PartitionKey, PartitionComparator> throwDuringExecution(RuntimeException toThrow, long...throwAt) {
         return (key) -> {
             BigInteger t = key.getTokenAsBigInteger();
-            return new PartitionComparator(null, null, null) {
+            return new PartitionComparator(null, null, null, mockRetryStrategyFactory) {
                 public PartitionStats call() {
                     for (long shouldThrow : throwAt)
                         if (t.longValue() == shouldThrow)
@@ -548,7 +549,7 @@ public class RangeComparatorTest {
             BigInteger t = key.getTokenAsBigInteger();
             taskSubmissions.countDown();
 
-            return new PartitionComparator(null, null, null) {
+            return new PartitionComparator(null, null, null, mockRetryStrategyFactory) {
                 public PartitionStats call() {
                     if (!gateIter.hasNext())
                         fail("Expected a latch to control task progress");


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@cassandra.apache.org
For additional commands, e-mail: commits-help@cassandra.apache.org