You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2022/03/30 23:35:10 UTC

[spark] branch master updated: [SPARK-38694][TESTS] Simplify Java UT code with Junit `assertThrows` Api

This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ef8fb9b  [SPARK-38694][TESTS] Simplify Java UT code with Junit `assertThrows` Api
ef8fb9b is described below

commit ef8fb9b9d84b6adfe5a4e03b6e775e709d624144
Author: yangjie01 <ya...@baidu.com>
AuthorDate: Wed Mar 30 18:32:37 2022 -0500

    [SPARK-38694][TESTS] Simplify Java UT code with Junit `assertThrows` Api
    
    ### What changes were proposed in this pull request?
    Spark's Java unit tests (UTs) contain code patterns such as:
    
    ```java
      @Test
      public void testAuthReplay() throws Exception {
        try {
          doSomeOperation();
          fail("Should have failed");
        } catch (Exception e) {
          assertTrue(doExceptionCheck(e));
        }
      }
    ```
    or
    ```java
      @Test(expected = SomeException.class)
      public void testAuthReplay() throws Exception {
        try {
          doSomeOperation();
          fail("Should have failed");
        } catch (Exception e) {
          assertTrue(doExceptionCheck(e));
          throw e;
        }
      }
    ```
    This PR uses the JUnit `assertThrows` API to simplify these and similar patterns.
    
    ### Why are the changes needed?
    Simplify code.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Pass GitHub Actions (GA)
    
    Closes #36008 from LuciferYang/SPARK-38694.
    
    Authored-by: yangjie01 <ya...@baidu.com>
    Signed-off-by: Sean Owen <sr...@gmail.com>
---
 .../spark/util/kvstore/InMemoryStoreSuite.java     |  21 +-
 .../apache/spark/util/kvstore/LevelDBSuite.java    |  21 +-
 .../apache/spark/util/kvstore/RocksDBSuite.java    |  21 +-
 .../spark/network/crypto/AuthIntegrationSuite.java |  39 +-
 .../spark/network/crypto/TransportCipherSuite.java |  21 +-
 .../apache/spark/network/sasl/SparkSaslSuite.java  |  41 +-
 .../server/OneForOneStreamManagerSuite.java        |  23 +-
 .../spark/network/sasl/SaslIntegrationSuite.java   |  37 +-
 .../network/shuffle/ExternalBlockHandlerSuite.java |  14 +-
 .../shuffle/ExternalShuffleBlockResolverSuite.java |  17 +-
 .../shuffle/ExternalShuffleSecuritySuite.java      |  16 +-
 .../shuffle/OneForOneBlockFetcherSuite.java        |  14 +-
 .../shuffle/RemoteBlockPushResolverSuite.java      | 464 +++++++++------------
 .../apache/spark/unsafe/types/UTF8StringSuite.java |   8 +-
 .../apache/spark/launcher/SparkLauncherSuite.java  |  15 +-
 .../shuffle/sort/PackedRecordPointerSuite.java     |  14 +-
 .../unsafe/map/AbstractBytesToBytesMapSuite.java   |  40 +-
 .../java/test/org/apache/spark/JavaAPISuite.java   |  16 +-
 .../spark/launcher/CommandBuilderUtilsSuite.java   |   7 +-
 .../apache/spark/launcher/LauncherServerSuite.java |  14 +-
 .../JavaRandomForestClassifierSuite.java           |   8 +-
 .../regression/JavaRandomForestRegressorSuite.java |   8 +-
 .../spark/ml/util/JavaDefaultReadWriteSuite.java   |   8 +-
 .../expressions/RowBasedKeyValueBatchSuite.java    |  60 +--
 .../spark/sql/JavaBeanDeserializationSuite.java    |  15 +-
 .../spark/sql/JavaColumnExpressionSuite.java       |  16 +-
 26 files changed, 317 insertions(+), 661 deletions(-)

diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/InMemoryStoreSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/InMemoryStoreSuite.java
index 198b6e8..b2acd1a 100644
--- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/InMemoryStoreSuite.java
+++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/InMemoryStoreSuite.java
@@ -34,24 +34,14 @@ public class InMemoryStoreSuite {
     t.id = "id";
     t.name = "name";
 
-    try {
-      store.read(CustomType1.class, t.key);
-      fail("Expected exception for non-existent object.");
-    } catch (NoSuchElementException nsee) {
-      // Expected.
-    }
+    assertThrows(NoSuchElementException.class, () -> store.read(CustomType1.class, t.key));
 
     store.write(t);
     assertEquals(t, store.read(t.getClass(), t.key));
     assertEquals(1L, store.count(t.getClass()));
 
     store.delete(t.getClass(), t.key);
-    try {
-      store.read(t.getClass(), t.key);
-      fail("Expected exception for deleted object.");
-    } catch (NoSuchElementException nsee) {
-      // Expected.
-    }
+    assertThrows(NoSuchElementException.class, () -> store.read(t.getClass(), t.key));
   }
 
   @Test
@@ -78,12 +68,7 @@ public class InMemoryStoreSuite {
     store.delete(t1.getClass(), t1.key);
     assertEquals(t2, store.read(t2.getClass(), t2.key));
     store.delete(t2.getClass(), t2.key);
-    try {
-      store.read(t2.getClass(), t2.key);
-      fail("Expected exception for deleted object.");
-    } catch (NoSuchElementException nsee) {
-      // Expected.
-    }
+    assertThrows(NoSuchElementException.class, () -> store.read(t2.getClass(), t2.key));
   }
 
   @Test
diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java
index c43c9b1..a7a2148 100644
--- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java
+++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/LevelDBSuite.java
@@ -71,36 +71,21 @@ public class LevelDBSuite {
     db.close();
     db = null;
 
-    try {
-      db = new LevelDB(dbpath);
-      fail("Should have failed version check.");
-    } catch (UnsupportedStoreVersionException e) {
-      // Expected.
-    }
+    assertThrows(UnsupportedStoreVersionException.class, () -> db = new LevelDB(dbpath));
   }
 
   @Test
   public void testObjectWriteReadDelete() throws Exception {
     CustomType1 t = createCustomType1(1);
 
-    try {
-      db.read(CustomType1.class, t.key);
-      fail("Expected exception for non-existent object.");
-    } catch (NoSuchElementException nsee) {
-      // Expected.
-    }
+    assertThrows(NoSuchElementException.class, () -> db.read(CustomType1.class, t.key));
 
     db.write(t);
     assertEquals(t, db.read(t.getClass(), t.key));
     assertEquals(1L, db.count(t.getClass()));
 
     db.delete(t.getClass(), t.key);
-    try {
-      db.read(t.getClass(), t.key);
-      fail("Expected exception for deleted object.");
-    } catch (NoSuchElementException nsee) {
-      // Expected.
-    }
+    assertThrows(NoSuchElementException.class, () -> db.read(t.getClass(), t.key));
 
     // Look into the actual DB and make sure that all the keys related to the type have been
     // removed.
diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/RocksDBSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/RocksDBSuite.java
index 04463ee..8112cbf 100644
--- a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/RocksDBSuite.java
+++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/RocksDBSuite.java
@@ -69,36 +69,21 @@ public class RocksDBSuite {
     db.close();
     db = null;
 
-    try {
-      db = new RocksDB(dbpath);
-      fail("Should have failed version check.");
-    } catch (UnsupportedStoreVersionException e) {
-      // Expected.
-    }
+    assertThrows(UnsupportedStoreVersionException.class, () -> db = new RocksDB(dbpath));
   }
 
   @Test
   public void testObjectWriteReadDelete() throws Exception {
     CustomType1 t = createCustomType1(1);
 
-    try {
-      db.read(CustomType1.class, t.key);
-      fail("Expected exception for non-existent object.");
-    } catch (NoSuchElementException nsee) {
-      // Expected.
-    }
+    assertThrows(NoSuchElementException.class, () -> db.read(CustomType1.class, t.key));
 
     db.write(t);
     assertEquals(t, db.read(t.getClass(), t.key));
     assertEquals(1L, db.count(t.getClass()));
 
     db.delete(t.getClass(), t.key);
-    try {
-      db.read(t.getClass(), t.key);
-      fail("Expected exception for deleted object.");
-    } catch (NoSuchElementException nsee) {
-      // Expected.
-    }
+    assertThrows(NoSuchElementException.class, () -> db.read(t.getClass(), t.key));
 
     // Look into the actual DB and make sure that all the keys related to the type have been
     // removed.
diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java
index d4bf28e..62ccccb 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/AuthIntegrationSuite.java
@@ -72,13 +72,9 @@ public class AuthIntegrationSuite {
     ctx = new AuthTestCtx();
     ctx.createServer("server");
 
-    try {
-      ctx.createClient("client");
-      fail("Should have failed to create client.");
-    } catch (Exception e) {
-      assertFalse(ctx.authRpcHandler.isAuthenticated());
-      assertFalse(ctx.serverChannel.isActive());
-    }
+    assertThrows(Exception.class, () -> ctx.createClient("client"));
+    assertFalse(ctx.authRpcHandler.isAuthenticated());
+    assertFalse(ctx.serverChannel.isActive());
   }
 
   @Test
@@ -115,13 +111,9 @@ public class AuthIntegrationSuite {
 
     assertNotNull(ctx.client.getChannel().pipeline()
       .remove(TransportCipher.ENCRYPTION_HANDLER_NAME));
-
-    try {
-      ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000);
-      fail("Should have failed unencrypted RPC.");
-    } catch (Exception e) {
-      assertTrue(ctx.authRpcHandler.isAuthenticated());
-    }
+    assertThrows(Exception.class,
+      () -> ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000));
+    assertTrue(ctx.authRpcHandler.isAuthenticated());
   }
 
   @Test
@@ -147,17 +139,14 @@ public class AuthIntegrationSuite {
     ctx.createServer("secret");
     ctx.createClient("secret");
 
-    try {
-      ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000);
-      fail("Should have failed unencrypted RPC.");
-    } catch (Exception e) {
-      assertTrue(ctx.authRpcHandler.isAuthenticated());
-      assertTrue(e.getMessage() + " is not an expected error", e.getMessage().contains("DDDDD"));
-      // Verify we receive the complete error message
-      int messageStart = e.getMessage().indexOf("DDDDD");
-      int messageEnd = e.getMessage().lastIndexOf("DDDDD") + 5;
-      assertEquals(testErrorMessageLength, messageEnd - messageStart);
-    }
+    Exception e = assertThrows(Exception.class,
+      () -> ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"), 5000));
+    assertTrue(ctx.authRpcHandler.isAuthenticated());
+    assertTrue(e.getMessage() + " is not an expected error", e.getMessage().contains("DDDDD"));
+    // Verify we receive the complete error message
+    int messageStart = e.getMessage().indexOf("DDDDD");
+    int messageEnd = e.getMessage().lastIndexOf("DDDDD") + 5;
+    assertEquals(testErrorMessageLength, messageEnd - messageStart);
   }
 
   private static class AuthTestCtx {
diff --git a/common/network-common/src/test/java/org/apache/spark/network/crypto/TransportCipherSuite.java b/common/network-common/src/test/java/org/apache/spark/network/crypto/TransportCipherSuite.java
index cff115d..cde5c1c 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/crypto/TransportCipherSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/crypto/TransportCipherSuite.java
@@ -32,7 +32,7 @@ import org.junit.Test;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.fail;
+import static org.junit.Assert.assertThrows;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.Mockito.mock;
@@ -67,21 +67,12 @@ public class TransportCipherSuite {
     ByteBuf buffer = Unpooled.wrappedBuffer(new byte[] { 1, 2 });
     ByteBuf buffer2 = Unpooled.wrappedBuffer(new byte[] { 1, 2 });
 
-    try {
-      channel.writeInbound(buffer);
-      fail("Should have raised InternalError");
-    } catch (InternalError expected) {
-      // expected
-      assertEquals(0, buffer.refCnt());
-    }
+    assertThrows(InternalError.class, () -> channel.writeInbound(buffer));
+    assertEquals(0, buffer.refCnt());
 
-    try {
-      channel.writeInbound(buffer2);
-      fail("Should have raised an exception");
-    } catch (Throwable expected) {
-      assertEquals(expected.getClass(), IOException.class);
-      assertEquals(0, buffer2.refCnt());
-    }
+    Throwable expected = assertThrows(Throwable.class, () -> channel.writeInbound(buffer2));
+    assertEquals(expected.getClass(), IOException.class);
+    assertEquals(0, buffer2.refCnt());
 
     // Simulate closing the connection
     assertFalse(channel.finish());
diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
index 2b0bcca..6096cd3 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
@@ -147,13 +147,11 @@ public class SparkSaslSuite {
       .when(rpcHandler)
       .receive(any(TransportClient.class), any(ByteBuffer.class), any(RpcResponseCallback.class));
 
-    SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false);
-    try {
+    try (SaslTestCtx ctx = new SaslTestCtx(rpcHandler, encrypt, false)) {
       ByteBuffer response = ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"),
         TimeUnit.SECONDS.toMillis(10));
       assertEquals("Pong", JavaUtils.bytesToString(response));
     } finally {
-      ctx.close();
       // There should be 2 terminated events; one for the client, one for the server.
       Throwable error = null;
       long deadline = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS);
@@ -301,19 +299,11 @@ public class SparkSaslSuite {
   }
 
   @Test
-  public void testServerAlwaysEncrypt() throws Exception {
-    SaslTestCtx ctx = null;
-    try {
-      ctx = new SaslTestCtx(mock(RpcHandler.class), false, false,
-        ImmutableMap.of("spark.network.sasl.serverAlwaysEncrypt", "true"));
-      fail("Should have failed to connect without encryption.");
-    } catch (Exception e) {
-      assertTrue(e.getCause() instanceof SaslException);
-    } finally {
-      if (ctx != null) {
-        ctx.close();
-      }
-    }
+  public void testServerAlwaysEncrypt() {
+    Exception re = assertThrows(Exception.class,
+      () -> new SaslTestCtx(mock(RpcHandler.class), false, false,
+              ImmutableMap.of("spark.network.sasl.serverAlwaysEncrypt", "true")));
+    assertTrue(re.getCause() instanceof SaslException);
   }
 
   @Test
@@ -321,18 +311,11 @@ public class SparkSaslSuite {
     // This test sets up an encrypted connection but then, using a client bootstrap, removes
     // the encryption handler from the client side. This should cause the server to not be
     // able to understand RPCs sent to it and thus close the connection.
-    SaslTestCtx ctx = null;
-    try {
-      ctx = new SaslTestCtx(mock(RpcHandler.class), true, true);
-      ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"),
-        TimeUnit.SECONDS.toMillis(10));
-      fail("Should have failed to send RPC to server.");
-    } catch (Exception e) {
+    try (SaslTestCtx ctx = new SaslTestCtx(mock(RpcHandler.class), true, true)) {
+      Exception e = assertThrows(Exception.class,
+        () -> ctx.client.sendRpcSync(JavaUtils.stringToBytes("Ping"),
+                TimeUnit.SECONDS.toMillis(10)));
       assertFalse(e.getCause() instanceof TimeoutException);
-    } finally {
-      if (ctx != null) {
-        ctx.close();
-      }
     }
   }
 
@@ -362,7 +345,7 @@ public class SparkSaslSuite {
     }
   }
 
-  private static class SaslTestCtx {
+  private static class SaslTestCtx implements AutoCloseable {
 
     final TransportClient client;
     final TransportServer server;
@@ -423,7 +406,7 @@ public class SparkSaslSuite {
       this.disableClientEncryption = disableClientEncryption;
     }
 
-    void close() {
+    public void close() {
       if (!disableClientEncryption) {
         assertEquals(encrypt, checker.foundEncryptionHandler);
       }
diff --git a/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java
index 634b40e..b65daaf 100644
--- a/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java
+++ b/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java
@@ -118,21 +118,12 @@ public class OneForOneStreamManagerSuite {
 
     Assert.assertEquals(2, manager.numStreamStates());
 
-    try {
-      manager.connectionTerminated(dummyChannel);
-      Assert.fail("connectionTerminated should throw exception when fails to release all buffers");
-
-    } catch (RuntimeException e) {
-
-      Mockito.verify(buffers, Mockito.times(1)).hasNext();
-      Mockito.verify(buffers, Mockito.times(1)).next();
-
-      Mockito.verify(buffers2, Mockito.times(2)).hasNext();
-      Mockito.verify(buffers2, Mockito.times(2)).next();
-
-      Mockito.verify(mockManagedBuffer, Mockito.times(1)).release();
-
-      Assert.assertEquals(0, manager.numStreamStates());
-    }
+    Assert.assertThrows(RuntimeException.class, () -> manager.connectionTerminated(dummyChannel));
+    Mockito.verify(buffers, Mockito.times(1)).hasNext();
+    Mockito.verify(buffers, Mockito.times(1)).next();
+    Mockito.verify(buffers2, Mockito.times(2)).hasNext();
+    Mockito.verify(buffers2, Mockito.times(2)).next();
+    Mockito.verify(mockManagedBuffer, Mockito.times(1)).release();
+    Assert.assertEquals(0, manager.numStreamStates());
   }
 }
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
index 96dfc3b..ec749cb 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
@@ -108,13 +108,10 @@ public class SaslIntegrationSuite {
     clientFactory = context.createClientFactory(
         Arrays.asList(new SaslClientBootstrap(conf, "unknown-app", badKeyHolder)));
 
-    try {
-      // Bootstrap should fail on startup.
-      clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
-      fail("Connection should have failed.");
-    } catch (Exception e) {
-      assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response"));
-    }
+    // Bootstrap should fail on startup.
+    Exception e = assertThrows(Exception.class,
+      () -> clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()));
+    assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response"));
   }
 
   @Test
@@ -122,20 +119,14 @@ public class SaslIntegrationSuite {
     clientFactory = context.createClientFactory(new ArrayList<>());
 
     TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
-    try {
-      client.sendRpcSync(ByteBuffer.allocate(13), TIMEOUT_MS);
-      fail("Should have failed");
-    } catch (Exception e) {
-      assertTrue(e.getMessage(), e.getMessage().contains("Expected SaslMessage"));
-    }
-
-    try {
-      // Guessing the right tag byte doesn't magically get you in...
-      client.sendRpcSync(ByteBuffer.wrap(new byte[] { (byte) 0xEA }), TIMEOUT_MS);
-      fail("Should have failed");
-    } catch (Exception e) {
-      assertTrue(e.getMessage(), e.getMessage().contains("java.lang.IndexOutOfBoundsException"));
-    }
+    Exception e1 = assertThrows(Exception.class,
+      () -> client.sendRpcSync(ByteBuffer.allocate(13), TIMEOUT_MS));
+    assertTrue(e1.getMessage(), e1.getMessage().contains("Expected SaslMessage"));
+
+    // Guessing the right tag byte doesn't magically get you in...
+    Exception e2 = assertThrows(Exception.class,
+      () -> client.sendRpcSync(ByteBuffer.wrap(new byte[] { (byte) 0xEA }), TIMEOUT_MS));
+    assertTrue(e2.getMessage(), e2.getMessage().contains("java.lang.IndexOutOfBoundsException"));
   }
 
   @Test
@@ -145,8 +136,8 @@ public class SaslIntegrationSuite {
       clientFactory = context.createClientFactory(
           Arrays.asList(new SaslClientBootstrap(conf, "app-1", secretKeyHolder)));
       try (TransportServer server = context.createServer()) {
-        clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
-      } catch (Exception e) {
+        Exception e = assertThrows(Exception.class,
+          () -> clientFactory.createClient(TestUtils.getLocalHost(), server.getPort()));
         assertTrue(e.getMessage(), e.getMessage().contains("Digest-challenge format violation"));
       }
     }
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java
index d45cbd5..14896c8 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java
@@ -332,21 +332,11 @@ public class ExternalBlockHandlerSuite {
     RpcResponseCallback callback = mock(RpcResponseCallback.class);
 
     ByteBuffer unserializableMsg = ByteBuffer.wrap(new byte[] { 0x12, 0x34, 0x56 });
-    try {
-      handler.receive(client, unserializableMsg, callback);
-      fail("Should have thrown");
-    } catch (Exception e) {
-      // pass
-    }
+    assertThrows(Exception.class, () -> handler.receive(client, unserializableMsg, callback));
 
     ByteBuffer unexpectedMsg = new UploadBlock("a", "e", "b", new byte[1],
       new byte[2]).toByteBuffer();
-    try {
-      handler.receive(client, unexpectedMsg, callback);
-      fail("Should have thrown");
-    } catch (UnsupportedOperationException e) {
-      // pass
-    }
+    assertThrows(Exception.class, () -> handler.receive(client, unexpectedMsg, callback));
 
     verify(callback, never()).onSuccess(any(ByteBuffer.class));
     verify(callback, never()).onFailure(any(Throwable.class));
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java
index 04d4bdf..ec195e8 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java
@@ -64,22 +64,15 @@ public class ExternalShuffleBlockResolverSuite {
   public void testBadRequests() throws IOException {
     ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null);
     // Unregistered executor
-    try {
-      resolver.getBlockData("app0", "exec1", 1, 1, 0);
-      fail("Should have failed");
-    } catch (RuntimeException e) {
-      assertTrue("Bad error message: " + e, e.getMessage().contains("not registered"));
-    }
+    RuntimeException e = assertThrows(RuntimeException.class,
+      () -> resolver.getBlockData("app0", "exec1", 1, 1, 0));
+    assertTrue("Bad error message: " + e, e.getMessage().contains("not registered"));
 
     // Nonexistent shuffle block
     resolver.registerExecutor("app0", "exec3",
       dataContext.createExecutorInfo(SORT_MANAGER));
-    try {
-      resolver.getBlockData("app0", "exec3", 1, 1, 0);
-      fail("Should have failed");
-    } catch (Exception e) {
-      // pass
-    }
+    assertThrows(Exception.class,
+      () -> resolver.getBlockData("app0", "exec3", 1, 1, 0));
   }
 
   @Test
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
index 883e643..c52ac31 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleSecuritySuite.java
@@ -70,20 +70,16 @@ public class ExternalShuffleSecuritySuite {
 
   @Test
   public void testBadAppId() {
-    try {
-      validate("wrong-app-id", "secret", false);
-    } catch (Exception e) {
-      assertTrue(e.getMessage(), e.getMessage().contains("Wrong appId!"));
-    }
+    Exception e = assertThrows(Exception.class,
+      () -> validate("wrong-app-id", "secret", false));
+    assertTrue(e.getMessage(), e.getMessage().contains("Wrong appId!"));
   }
 
   @Test
   public void testBadSecret() {
-    try {
-      validate("my-app-id", "bad-secret", false);
-    } catch (Exception e) {
-      assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response"));
-    }
+    Exception e = assertThrows(Exception.class,
+      () -> validate("my-app-id", "bad-secret", false));
+    assertTrue(e.getMessage(), e.getMessage().contains("Mismatched response"));
   }
 
   @Test
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
index cc4640d..5f3d3c8 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
@@ -189,16 +189,10 @@ public class OneForOneBlockFetcherSuite {
 
   @Test
   public void testEmptyBlockFetch() {
-    try {
-      fetchBlocks(
-        Maps.newLinkedHashMap(),
-        new String[] {},
-        new OpenBlocks("app-id", "exec-id", new String[] {}),
-        conf);
-      fail();
-    } catch (IllegalArgumentException e) {
-      assertEquals("Zero-sized blockIds array", e.getMessage());
-    }
+    IllegalArgumentException e = assertThrows(IllegalArgumentException.class,
+      () -> fetchBlocks(Maps.newLinkedHashMap(), new String[] {},
+        new OpenBlocks("app-id", "exec-id", new String[] {}), conf));
+    assertEquals("Zero-sized blockIds array", e.getMessage());
   }
 
   @Test
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RemoteBlockPushResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RemoteBlockPushResolverSuite.java
index f76afae..20aae7c 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RemoteBlockPushResolverSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RemoteBlockPushResolverSuite.java
@@ -32,7 +32,6 @@ import java.util.Map;
 import java.util.concurrent.Semaphore;
 import java.util.concurrent.ThreadLocalRandom;
 
-import com.google.common.base.Throwables;
 import com.google.common.collect.ImmutableMap;
 
 import org.apache.commons.io.FileUtils;
@@ -120,14 +119,11 @@ public class RemoteBlockPushResolverSuite {
     assertTrue(errorHandler.shouldLogError(new Throwable()));
   }
 
-  @Test(expected = RuntimeException.class)
+  @Test
   public void testNoIndexFile() {
-    try {
-      pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0);
-    } catch (Throwable t) {
-      assertTrue(t.getMessage().startsWith("Merged shuffle index file"));
-      Throwables.propagate(t);
-    }
+    RuntimeException re = assertThrows(RuntimeException.class,
+      () -> pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0));
+    assertTrue(re.getMessage().startsWith("Merged shuffle index file"));
   }
 
   @Test
@@ -303,7 +299,7 @@ public class RemoteBlockPushResolverSuite {
     validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[]{9}, new int[][]{{0}});
   }
 
-  @Test(expected = BlockPushNonFatalFailure.class)
+  @Test
   public void testBlockReceivedAfterMergeFinalize() throws IOException {
     ByteBuffer[] blocks = new ByteBuffer[]{
       ByteBuffer.wrap(new byte[4]),
@@ -319,18 +315,15 @@ public class RemoteBlockPushResolverSuite {
     StreamCallbackWithID stream1 = pushResolver.receiveBlockDataAsStream(
       new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 1, 0, 0));
     stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[4]));
-    try {
-      stream1.onComplete(stream1.getID());
-    } catch (BlockPushNonFatalFailure e) {
-      BlockPushReturnCode errorCode =
-        (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(e.getResponse());
-      assertEquals(BlockPushNonFatalFailure.ReturnCode.TOO_LATE_BLOCK_PUSH.id(),
-        errorCode.returnCode);
-      assertEquals(errorCode.failureBlockId, stream1.getID());
-      MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0);
-      validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[]{9}, new int[][]{{0}});
-      throw e;
-    }
+    BlockPushNonFatalFailure e = assertThrows(BlockPushNonFatalFailure.class,
+      () -> stream1.onComplete(stream1.getID()));
+    BlockPushReturnCode errorCode =
+      (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(e.getResponse());
+    assertEquals(BlockPushNonFatalFailure.ReturnCode.TOO_LATE_BLOCK_PUSH.id(),
+      errorCode.returnCode);
+    assertEquals(errorCode.failureBlockId, stream1.getID());
+    MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0);
+    validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[]{9}, new int[][]{{0}});
   }
 
   @Test
@@ -367,7 +360,7 @@ public class RemoteBlockPushResolverSuite {
     assertArrayEquals(expectedBytes, mb.nioByteBuffer().array());
   }
 
-  @Test(expected = BlockPushNonFatalFailure.class)
+  @Test
   public void testCollision() throws IOException {
     StreamCallbackWithID stream1 =
       pushResolver.receiveBlockDataAsStream(
@@ -379,19 +372,16 @@ public class RemoteBlockPushResolverSuite {
     // This should be deferred
     stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[5]));
     // Since stream2 didn't get any opportunity it will throw couldn't find opportunity error
-    try {
-      stream2.onComplete(stream2.getID());
-    } catch (BlockPushNonFatalFailure e) {
-      BlockPushReturnCode errorCode =
-        (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(e.getResponse());
-      assertEquals(BlockPushNonFatalFailure.ReturnCode.BLOCK_APPEND_COLLISION_DETECTED.id(),
-        errorCode.returnCode);
-      assertEquals(errorCode.failureBlockId, stream2.getID());
-      throw e;
-    }
+    BlockPushNonFatalFailure e = assertThrows(BlockPushNonFatalFailure.class,
+      () -> stream2.onComplete(stream2.getID()));
+    BlockPushReturnCode errorCode =
+            (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(e.getResponse());
+    assertEquals(BlockPushNonFatalFailure.ReturnCode.BLOCK_APPEND_COLLISION_DETECTED.id(),
+            errorCode.returnCode);
+    assertEquals(errorCode.failureBlockId, stream2.getID());
   }
 
-  @Test(expected = BlockPushNonFatalFailure.class)
+  @Test
   public void testFailureInAStreamDoesNotInterfereWithStreamWhichIsWriting() throws IOException {
     StreamCallbackWithID stream1 =
       pushResolver.receiveBlockDataAsStream(
@@ -408,17 +398,13 @@ public class RemoteBlockPushResolverSuite {
     // This should be deferred
     stream3.onData(stream3.getID(), ByteBuffer.wrap(new byte[5]));
     // Since this stream didn't get any opportunity it will throw couldn't find opportunity error
-    BlockPushNonFatalFailure failedEx = null;
-    try {
-      stream3.onComplete(stream3.getID());
-    } catch (BlockPushNonFatalFailure e) {
-      BlockPushReturnCode errorCode =
-        (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(e.getResponse());
-      assertEquals(BlockPushNonFatalFailure.ReturnCode.BLOCK_APPEND_COLLISION_DETECTED.id(),
-        errorCode.returnCode);
-      assertEquals(errorCode.failureBlockId, stream3.getID());
-      failedEx = e;
-    }
+    BlockPushNonFatalFailure e = assertThrows(BlockPushNonFatalFailure.class,
+      () -> stream3.onComplete(stream3.getID()));
+    BlockPushReturnCode errorCode =
+      (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(e.getResponse());
+    assertEquals(BlockPushNonFatalFailure.ReturnCode.BLOCK_APPEND_COLLISION_DETECTED.id(),
+      errorCode.returnCode);
+    assertEquals(errorCode.failureBlockId, stream3.getID());
     // stream 1 now completes
     stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2]));
     stream1.onComplete(stream1.getID());
@@ -426,12 +412,9 @@ public class RemoteBlockPushResolverSuite {
     pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 0));
     MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0);
     validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[] {4}, new int[][] {{0}});
-    if (failedEx != null) {
-      throw failedEx;
-    }
   }
 
-  @Test(expected = IllegalArgumentException.class)
+  @Test
   public void testUpdateLocalDirsOnlyOnce() throws IOException {
     String testApp = "updateLocalDirsOnlyOnceTest";
     Path[] activeLocalDirs = createLocalDirs(1);
@@ -449,32 +432,25 @@ public class RemoteBlockPushResolverSuite {
     assertTrue(pushResolver.getMergedBlockDirs(testApp)[0].contains(
       activeLocalDirs[0].toFile().getPath()));
     removeApplication(testApp);
-    try {
-      pushResolver.getMergedBlockDirs(testApp);
-    } catch (IllegalArgumentException e) {
-      assertEquals(e.getMessage(),
-        "application " + testApp + " is not registered or NM was restarted.");
-      throw e;
-    }
+    IllegalArgumentException e = assertThrows(IllegalArgumentException.class,
+      () -> pushResolver.getMergedBlockDirs(testApp));
+    assertEquals(e.getMessage(),
+      "application " + testApp + " is not registered or NM was restarted.");
   }
 
-  @Test(expected = IllegalArgumentException.class)
+  @Test
   public void testExecutorRegisterWithInvalidJsonForPushShuffle() throws IOException {
     String testApp = "executorRegisterWithInvalidShuffleManagerMeta";
     Path[] activeLocalDirs = createLocalDirs(1);
-    try {
-      registerExecutor(testApp, prepareLocalDirs(activeLocalDirs, MERGE_DIRECTORY),
-        INVALID_MERGE_DIRECTORY_META);
-    } catch (IllegalArgumentException re) {
-      assertEquals(
-        "Failed to get the merge directory information from the shuffleManagerMeta " +
-          "shuffleManager:{\"mergeDirInvalid\": \"merge_manager_2\", \"attemptId\": \"2\"} in " +
-          "executor registration message", re.getMessage());
-      throw re;
-    }
+    IllegalArgumentException re = assertThrows(IllegalArgumentException.class,
+      () -> registerExecutor(testApp, prepareLocalDirs(activeLocalDirs, MERGE_DIRECTORY),
+        INVALID_MERGE_DIRECTORY_META));
+    assertEquals("Failed to get the merge directory information from the shuffleManagerMeta " +
+      "shuffleManager:{\"mergeDirInvalid\": \"merge_manager_2\", \"attemptId\": \"2\"} in " +
+      "executor registration message", re.getMessage());
   }
 
-  @Test(expected = IllegalArgumentException.class)
+  @Test
   public void testExecutorRegistrationFromTwoAppAttempts() throws IOException {
     String testApp = "testExecutorRegistrationFromTwoAppAttempts";
     Path[] attempt1LocalDirs = createLocalDirs(1);
@@ -502,13 +478,10 @@ public class RemoteBlockPushResolverSuite {
     assertTrue(pushResolver.getMergedBlockDirs(testApp)[0].contains(
       attempt2LocalDirs[0].toFile().getPath()));
     removeApplication(testApp);
-    try {
-      pushResolver.getMergedBlockDirs(testApp);
-    } catch (IllegalArgumentException e) {
-      assertEquals(e.getMessage(),
-        "application " + testApp + " is not registered or NM was restarted.");
-      throw e;
-    }
+    IllegalArgumentException e = assertThrows(IllegalArgumentException.class,
+      () -> pushResolver.getMergedBlockDirs(testApp));
+    assertEquals(e.getMessage(),
+      "application " + testApp + " is not registered or NM was restarted.");
   }
 
   @Test
@@ -673,7 +646,7 @@ public class RemoteBlockPushResolverSuite {
     validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[] {4, 5}, new int[][] {{0}, {1}});
   }
 
-  @Test(expected = IllegalStateException.class)
+  @Test
   public void testIOExceptionsExceededThreshold() throws IOException {
     RemoteBlockPushResolver.PushBlockStreamCallback callback =
       (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream(
@@ -687,28 +660,23 @@ public class RemoteBlockPushResolverSuite {
       RemoteBlockPushResolver.PushBlockStreamCallback callback1 =
         (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream(
           new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, i, 0, 0));
-      try {
-        callback1.onData(callback1.getID(), ByteBuffer.wrap(new byte[2]));
-      } catch (IOException ioe) {
-        // this will throw IOException so the client can retry.
-        callback1.onFailure(callback1.getID(), ioe);
-      }
+      IOException ioe = assertThrows(IOException.class,
+        () -> callback1.onData(callback1.getID(), ByteBuffer.wrap(new byte[2])));
+      // this will throw IOException so the client can retry.
+      callback1.onFailure(callback1.getID(), ioe);
     }
     assertEquals(4, partitionInfo.getNumIOExceptions());
     // After 4 IOException, the server will respond with IOExceptions exceeded threshold
-    try {
-      RemoteBlockPushResolver.PushBlockStreamCallback callback2 =
-        (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream(
-          new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 5, 0, 0));
-      callback2.onData(callback.getID(), ByteBuffer.wrap(new byte[1]));
-    } catch (Throwable t) {
-      assertEquals("IOExceptions exceeded the threshold when merging shufflePush_0_0_5_0",
-        t.getMessage());
-      throw t;
-    }
+    RemoteBlockPushResolver.PushBlockStreamCallback callback2 =
+      (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream(
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 5, 0, 0));
+    IllegalStateException e = assertThrows(IllegalStateException.class,
+      () -> callback2.onData(callback.getID(), ByteBuffer.wrap(new byte[1])));
+    assertEquals("IOExceptions exceeded the threshold when merging shufflePush_0_0_5_0",
+      e.getMessage());
   }
 
-  @Test(expected = IllegalStateException.class)
+  @Test
   public void testIOExceptionsDuringMetaUpdateIncreasesExceptionCount() throws IOException {
     useTestFiles(true, false);
     RemoteBlockPushResolver.PushBlockStreamCallback callback =
@@ -730,37 +698,29 @@ public class RemoteBlockPushResolverSuite {
     assertEquals(4, partitionInfo.getNumIOExceptions());
     // After 4 IOException, the server will respond with IOExceptions exceeded threshold for any
     // new request for this partition.
-    try {
-      RemoteBlockPushResolver.PushBlockStreamCallback callback2 =
+    RemoteBlockPushResolver.PushBlockStreamCallback callback2 =
       (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream(
         new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 5, 0, 0));
-      callback2.onData(callback2.getID(), ByteBuffer.wrap(new byte[4]));
-      callback2.onComplete(callback2.getID());
-    } catch (Throwable t) {
-      assertEquals("IOExceptions exceeded the threshold when merging shufflePush_0_0_5_0",
-        t.getMessage());
-      throw t;
-    }
+    callback2.onData(callback2.getID(), ByteBuffer.wrap(new byte[4]));
+    IllegalStateException e = assertThrows(IllegalStateException.class,
+      () -> callback2.onComplete(callback2.getID()));
+    assertEquals("IOExceptions exceeded the threshold when merging shufflePush_0_0_5_0",
+      e.getMessage());
   }
 
-  @Test(expected = IllegalStateException.class)
-  public void testRequestForAbortedShufflePartitionThrowsException() {
-    try {
-      testIOExceptionsDuringMetaUpdateIncreasesExceptionCount();
-    } catch (Throwable t) {
-      // No more blocks can be merged to this partition.
-    }
-    try {
-      pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 10, 0, 0));
-    } catch (Throwable t) {
-      assertEquals("IOExceptions exceeded the threshold when merging shufflePush_0_0_10_0",
-        t.getMessage());
-      throw t;
-    }
+  @Test
+  public void testRequestForAbortedShufflePartitionThrowsException() throws IOException {
+    // No more blocks can be merged to this partition.
+    testIOExceptionsDuringMetaUpdateIncreasesExceptionCount();
+
+    IllegalStateException t = assertThrows(IllegalStateException.class,
+      () -> pushResolver.receiveBlockDataAsStream(
+        new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 10, 0, 0)));
+    assertEquals("IOExceptions exceeded the threshold when merging shufflePush_0_0_10_0",
+      t.getMessage());
   }
 
-  @Test(expected = IllegalStateException.class)
+  @Test
   public void testPendingBlockIsAbortedImmediately() throws IOException {
     useTestFiles(true, false);
     RemoteBlockPushResolver.PushBlockStreamCallback callback =
@@ -773,27 +733,25 @@ public class RemoteBlockPushResolverSuite {
       RemoteBlockPushResolver.PushBlockStreamCallback callback1 =
         (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream(
           new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, i, 0, 0));
-      try {
-        callback1.onData(callback1.getID(), ByteBuffer.wrap(new byte[5]));
+      callback1.onData(callback1.getID(), ByteBuffer.wrap(new byte[5]));
+      if (i < 5) {
         // This will complete without any exceptions but the exception count is increased.
         callback1.onComplete(callback1.getID());
-      } catch (Throwable t) {
+      } else {
+        Throwable t = assertThrows(Throwable.class, () -> callback1.onComplete(callback1.getID()));
         callback1.onFailure(callback1.getID(), t);
       }
     }
     assertEquals(5, partitionInfo.getNumIOExceptions());
     // The server will respond with IOExceptions exceeded threshold for any additional attempts
     // to write.
-    try {
-      callback.onData(callback.getID(), ByteBuffer.wrap(new byte[4]));
-    } catch (Throwable t) {
-      assertEquals("IOExceptions exceeded the threshold when merging shufflePush_0_0_0_0",
-        t.getMessage());
-      throw t;
-    }
+    IllegalStateException e = assertThrows(IllegalStateException.class,
+      () -> callback.onData(callback.getID(), ByteBuffer.wrap(new byte[4])));
+    assertEquals("IOExceptions exceeded the threshold when merging shufflePush_0_0_0_0",
+      e.getMessage());
   }
 
-  @Test(expected = IllegalStateException.class)
+  @Test
   public void testWritingPendingBufsIsAbortedImmediatelyDuringComplete() throws IOException {
     useTestFiles(true, false);
     RemoteBlockPushResolver.PushBlockStreamCallback callback =
@@ -806,13 +764,9 @@ public class RemoteBlockPushResolverSuite {
       RemoteBlockPushResolver.PushBlockStreamCallback callback1 =
         (RemoteBlockPushResolver.PushBlockStreamCallback) pushResolver.receiveBlockDataAsStream(
           new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, i, 0, 0));
-      try {
-        callback1.onData(callback1.getID(), ByteBuffer.wrap(new byte[5]));
-        // This will complete without any exceptions but the exception count is increased.
-        callback1.onComplete(callback1.getID());
-      } catch (Throwable t) {
-        callback1.onFailure(callback1.getID(), t);
-      }
+      callback1.onData(callback1.getID(), ByteBuffer.wrap(new byte[5]));
+      // This will complete without any exceptions but the exception count is increased.
+      callback1.onComplete(callback1.getID());
     }
     assertEquals(4, partitionInfo.getNumIOExceptions());
     RemoteBlockPushResolver.PushBlockStreamCallback callback2 =
@@ -822,22 +776,16 @@ public class RemoteBlockPushResolverSuite {
     // This is deferred
     callback.onData(callback.getID(), ByteBuffer.wrap(new byte[4]));
     // Callback2 completes which will throw another exception.
-    try {
-      callback2.onComplete(callback2.getID());
-    } catch (Throwable t) {
-      callback2.onFailure(callback2.getID(), t);
-    }
+    Throwable t = assertThrows(Throwable.class, () -> callback2.onComplete(callback2.getID()));
+    callback2.onFailure(callback2.getID(), t);
     assertEquals(5, partitionInfo.getNumIOExceptions());
     // Restore index file so that any further writes to it are successful and any exceptions are
     // due to IOExceptions exceeding threshold.
     testIndexFile.restore();
-    try {
-      callback.onComplete(callback.getID());
-    } catch (Throwable t) {
-      assertEquals("IOExceptions exceeded the threshold when merging shufflePush_0_0_0_0",
-        t.getMessage());
-      throw t;
-    }
+    IllegalStateException ie = assertThrows(IllegalStateException.class,
+      () -> callback.onComplete(callback.getID()));
+    assertEquals("IOExceptions exceeded the threshold when merging shufflePush_0_0_0_0",
+      ie.getMessage());
   }
 
   @Test
@@ -894,7 +842,7 @@ public class RemoteBlockPushResolverSuite {
     removeApplication(TEST_APP);
   }
 
-  @Test(expected = BlockPushNonFatalFailure.class)
+  @Test
   public void testFailureAfterDuplicateBlockDoesNotInterfereActiveStream() throws IOException {
     StreamCallbackWithID stream1 =
       pushResolver.receiveBlockDataAsStream(
@@ -918,17 +866,13 @@ public class RemoteBlockPushResolverSuite {
         new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 0, 0, 2, 0, 0));
     // This should be deferred as stream 2 is still the active stream
     stream3.onData(stream3.getID(), ByteBuffer.wrap(new byte[2]));
-    BlockPushNonFatalFailure failedEx = null;
-    try {
-      stream3.onComplete(stream3.getID());
-    } catch (BlockPushNonFatalFailure e) {
-      BlockPushReturnCode errorCode =
-        (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(e.getResponse());
-      assertEquals(BlockPushNonFatalFailure.ReturnCode.BLOCK_APPEND_COLLISION_DETECTED.id(),
-        errorCode.returnCode);
-      assertEquals(errorCode.failureBlockId, stream3.getID());
-      failedEx = e;
-    }
+    BlockPushNonFatalFailure e = assertThrows(BlockPushNonFatalFailure.class,
+      () -> stream3.onComplete(stream3.getID()));
+    BlockPushReturnCode errorCode =
+      (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(e.getResponse());
+    assertEquals(BlockPushNonFatalFailure.ReturnCode.BLOCK_APPEND_COLLISION_DETECTED.id(),
+      errorCode.returnCode);
+    assertEquals(errorCode.failureBlockId, stream3.getID());
     // Stream 2 writes more and completes
     stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[4]));
     stream2.onComplete(stream2.getID());
@@ -936,12 +880,9 @@ public class RemoteBlockPushResolverSuite {
     MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0);
     validateChunks(TEST_APP, 0, 0, 0, blockMeta, new int[] {11}, new int[][] {{0, 1}});
     removeApplication(TEST_APP);
-    if (failedEx != null) {
-      throw failedEx;
-    }
   }
 
-  @Test(expected = BlockPushNonFatalFailure.class)
+  @Test
   public void testPushBlockFromPreviousAttemptIsRejected()
       throws IOException, InterruptedException {
     Semaphore closed = new Semaphore(0);
@@ -997,22 +938,19 @@ public class RemoteBlockPushResolverSuite {
       assertFalse(partitionInfo.getMetaFile().getChannel().isOpen());
       assertFalse(partitionInfo.getIndexFile().getChannel().isOpen());
     }
-    try {
-      pushResolver.receiveBlockDataAsStream(
-        new PushBlockStream(testApp, 1, 0, 0, 1, 0, 0));
-    } catch (BlockPushNonFatalFailure re) {
-      BlockPushReturnCode errorCode =
-        (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(re.getResponse());
-      assertEquals(BlockPushNonFatalFailure.ReturnCode.TOO_OLD_ATTEMPT_PUSH.id(),
-        errorCode.returnCode);
-      assertEquals(errorCode.failureBlockId, stream2.getID());
-      throw re;
-    }
+    BlockPushNonFatalFailure re = assertThrows(BlockPushNonFatalFailure.class,
+      () -> pushResolver.receiveBlockDataAsStream(
+        new PushBlockStream(testApp, 1, 0, 0, 1, 0, 0)));
+    BlockPushReturnCode errorCode =
+      (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(re.getResponse());
+    assertEquals(BlockPushNonFatalFailure.ReturnCode.TOO_OLD_ATTEMPT_PUSH.id(),
+      errorCode.returnCode);
+    assertEquals(errorCode.failureBlockId, stream2.getID());
   }
 
-  @Test(expected = IllegalArgumentException.class)
+  @Test
   public void testFinalizeShuffleMergeFromPreviousAttemptIsAborted()
-    throws IOException, InterruptedException {
+    throws IOException {
     String testApp = "testFinalizeShuffleMergeFromPreviousAttemptIsAborted";
     Path[] attempt1LocalDirs = createLocalDirs(1);
     registerExecutor(testApp,
@@ -1032,15 +970,13 @@ public class RemoteBlockPushResolverSuite {
     registerExecutor(testApp,
       prepareLocalDirs(attempt2LocalDirs, MERGE_DIRECTORY + "_" + ATTEMPT_ID_2),
       MERGE_DIRECTORY_META_2);
-    try {
-      pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(testApp, ATTEMPT_ID_1, 0, 0));
-    } catch (IllegalArgumentException e) {
-      assertEquals(e.getMessage(),
-        String.format("The attempt id %s in this FinalizeShuffleMerge message does not " +
-            "match with the current attempt id %s stored in shuffle service for application %s",
-          ATTEMPT_ID_1, ATTEMPT_ID_2, testApp));
-      throw e;
-    }
+    IllegalArgumentException e = assertThrows(IllegalArgumentException.class,
+      () -> pushResolver.finalizeShuffleMerge(
+              new FinalizeShuffleMerge(testApp, ATTEMPT_ID_1, 0, 0)));
+    assertEquals(e.getMessage(),
+      String.format("The attempt id %s in this FinalizeShuffleMerge message does not " +
+        "match with the current attempt id %s stored in shuffle service for application %s",
+        ATTEMPT_ID_1, ATTEMPT_ID_2, testApp));
   }
 
   @Test(expected = ClosedChannelException.class)
@@ -1095,16 +1031,13 @@ public class RemoteBlockPushResolverSuite {
     stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2]));
     stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2]));
     stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2]));
-    try {
-      // stream 1 push should be rejected as it is from an older shuffleMergeId
-      stream1.onComplete(stream1.getID());
-    } catch (BlockPushNonFatalFailure e) {
-      BlockPushReturnCode errorCode =
-        (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(e.getResponse());
-      assertEquals(BlockPushNonFatalFailure.ReturnCode.STALE_BLOCK_PUSH.id(),
-        errorCode.returnCode);
-      assertEquals(errorCode.failureBlockId, stream1.getID());
-    }
+    BlockPushNonFatalFailure e = assertThrows(BlockPushNonFatalFailure.class,
+      () -> stream1.onComplete(stream1.getID()));
+    BlockPushReturnCode errorCode =
+      (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(e.getResponse());
+    assertEquals(BlockPushNonFatalFailure.ReturnCode.STALE_BLOCK_PUSH.id(),
+      errorCode.returnCode);
+    assertEquals(errorCode.failureBlockId, stream1.getID());
     // stream 2 now completes
     stream2.onComplete(stream2.getID());
     pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 2));
@@ -1124,25 +1057,22 @@ public class RemoteBlockPushResolverSuite {
     stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2]));
     stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2]));
     stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2]));
-    try {
-      // stream 1 push should be rejected as it is from an older shuffleMergeId
-      stream1.onComplete(stream1.getID());
-    } catch (BlockPushNonFatalFailure e) {
-      BlockPushReturnCode errorCode =
-        (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(e.getResponse());
-      assertEquals(BlockPushNonFatalFailure.ReturnCode.STALE_BLOCK_PUSH.id(),
-        errorCode.returnCode);
-      assertEquals(errorCode.failureBlockId, stream1.getID());
-    }
+    // stream 1 push should be rejected as it is from an older shuffleMergeId
+    BlockPushNonFatalFailure e = assertThrows(BlockPushNonFatalFailure.class,
+      () -> stream1.onComplete(stream1.getID()));
+    BlockPushReturnCode errorCode =
+      (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(e.getResponse());
+    assertEquals(BlockPushNonFatalFailure.ReturnCode.STALE_BLOCK_PUSH.id(),
+      errorCode.returnCode);
+    assertEquals(errorCode.failureBlockId, stream1.getID());
     // stream 2 now completes
     stream2.onComplete(stream2.getID());
-    try {
-      pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 1));
-    } catch(RuntimeException re) {
-      assertEquals("Shuffle merge finalize request for shuffle 0 with shuffleMergeId 1 is stale"
-        + " shuffle finalize request as shuffle blocks of a higher shuffleMergeId for the shuffle"
-          + " is already being pushed", re.getMessage());
-    }
+    RuntimeException re = assertThrows(RuntimeException.class,
+      () -> pushResolver.finalizeShuffleMerge(
+              new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 1)));
+    assertEquals("Shuffle merge finalize request for shuffle 0 with shuffleMergeId 1 is stale"
+      + " shuffle finalize request as shuffle blocks of a higher shuffleMergeId for the shuffle"
+      + " is already being pushed", re.getMessage());
     pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 2));
 
     MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 2, 0);
@@ -1180,42 +1110,33 @@ public class RemoteBlockPushResolverSuite {
     stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2]));
     stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2]));
     stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2]));
-    try {
-      // stream 1 push should be rejected as it is from an older shuffleMergeId
-      stream1.onComplete(stream1.getID());
-    } catch (BlockPushNonFatalFailure e) {
-      BlockPushReturnCode errorCode =
-        (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(e.getResponse());
-      assertEquals(BlockPushNonFatalFailure.ReturnCode.STALE_BLOCK_PUSH.id(),
-        errorCode.returnCode);
-      assertEquals(errorCode.failureBlockId, stream1.getID());
-    }
+    // stream 1 push should be rejected as it is from an older shuffleMergeId
+    BlockPushNonFatalFailure e = assertThrows(BlockPushNonFatalFailure.class,
+      () -> stream1.onComplete(stream1.getID()));
+    BlockPushReturnCode errorCode =
+      (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(e.getResponse());
+    assertEquals(BlockPushNonFatalFailure.ReturnCode.STALE_BLOCK_PUSH.id(),
+      errorCode.returnCode);
+    assertEquals(errorCode.failureBlockId, stream1.getID());
     // stream 2 now completes
     stream2.onComplete(stream2.getID());
     pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 2));
-    try {
-      pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0);
-    } catch(RuntimeException re) {
-      assertEquals("MergedBlockMeta fetch for shuffle 0 with shuffleMergeId 0 reduceId 0"
-        + " is stale shuffle block fetch request as shuffle blocks of a higher shuffleMergeId for"
-        + " the shuffle is available", re.getMessage());
-    }
-
-    try {
-      pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 1));
-    } catch(RuntimeException re) {
-      assertEquals("Shuffle merge finalize request for shuffle 0 with shuffleMergeId 1 is stale"
-        + " shuffle finalize request as shuffle blocks of a higher shuffleMergeId for the shuffle"
-          + " is already being pushed", re.getMessage());
-    }
-    try {
-      pushResolver.getMergedBlockData(TEST_APP, 0, 1, 0, 0);
-    } catch(RuntimeException re) {
-      assertEquals("MergedBlockData fetch for shuffle 0 with shuffleMergeId 1 reduceId 0"
-        + " is stale shuffle block fetch request as shuffle blocks of a higher shuffleMergeId for"
-        + " the shuffle is available", re.getMessage());
-    }
-
+    RuntimeException re0 = assertThrows(RuntimeException.class,
+      () -> pushResolver.getMergedBlockMeta(TEST_APP, 0, 0, 0));
+    assertEquals("MergedBlockMeta fetch for shuffle 0 with shuffleMergeId 0 reduceId 0"
+      + " is stale shuffle block fetch request as shuffle blocks of a higher shuffleMergeId for"
+      + " the shuffle is available", re0.getMessage());
+    RuntimeException re1 = assertThrows(RuntimeException.class,
+      () -> pushResolver.finalizeShuffleMerge(
+              new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 0, 1)));
+    assertEquals("Shuffle merge finalize request for shuffle 0 with shuffleMergeId 1 is stale"
+      + " shuffle finalize request as shuffle blocks of a higher shuffleMergeId for the shuffle"
+      + " is already being pushed", re1.getMessage());
+    RuntimeException re2 = assertThrows(RuntimeException.class,
+      () -> pushResolver.getMergedBlockData(TEST_APP, 0, 1, 0, 0));
+    assertEquals("MergedBlockData fetch for shuffle 0 with shuffleMergeId 1 reduceId 0"
+      + " is stale shuffle block fetch request as shuffle blocks of a higher shuffleMergeId for"
+      + " the shuffle is available", re2.getMessage());
     MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 2, 0);
     validateChunks(TEST_APP, 0, 2, 0, blockMeta, new int[]{4}, new int[][]{{0}});
   }
@@ -1324,37 +1245,28 @@ public class RemoteBlockPushResolverSuite {
     stream1.onComplete(stream1.getID());
     //shuffle 1 0 is finalized
     pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, NO_ATTEMPT_ID, 1, 0));
-    BlockPushNonFatalFailure errorToValidate = null;
-    try {
-      //shufflePush_1_0_0_200 is received by the server after finalization of shuffle 1 0 which
-      //should be rejected
-      StreamCallbackWithID failureCallback = pushResolver.receiveBlockDataAsStream(
-          new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 1, 0, 0, 200, 0));
-      failureCallback.onComplete(failureCallback.getID());
-    } catch (BlockPushNonFatalFailure e) {
-      BlockPushReturnCode errorCode =
-          (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(e.getResponse());
-      assertEquals(BlockPushNonFatalFailure.ReturnCode.TOO_LATE_BLOCK_PUSH.id(),
-          errorCode.returnCode);
-      errorToValidate = e;
-      assertEquals(errorCode.failureBlockId, "shufflePush_1_0_0_200");
-    }
-    assertNotNull("shufflePush_1_0_0_200 should be rejected", errorToValidate);
-    try {
-      //shufflePush_1_0_1_100 is received by the server after finalization of shuffle 1 0 which
-      //should also be rejected
-      StreamCallbackWithID failureCallback = pushResolver.receiveBlockDataAsStream(
-          new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 1, 0, 1, 100, 0));
-      failureCallback.onComplete(failureCallback.getID());
-    } catch (BlockPushNonFatalFailure e) {
-      BlockPushReturnCode errorCode =
-          (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(e.getResponse());
-      assertEquals(BlockPushNonFatalFailure.ReturnCode.TOO_LATE_BLOCK_PUSH.id(),
-          errorCode.returnCode);
-      errorToValidate = e;
-      assertEquals(errorCode.failureBlockId, "shufflePush_1_0_1_100");
-    }
-    assertNotNull("shufflePush_1_0_1_100 should be rejected", errorToValidate);
+    //shufflePush_1_0_0_200 is received by the server after finalization of shuffle 1 0 which
+    //should be rejected
+    StreamCallbackWithID failureCallback0 = pushResolver.receiveBlockDataAsStream(
+      new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 1, 0, 0, 200, 0));
+    BlockPushNonFatalFailure e0 = assertThrows(BlockPushNonFatalFailure.class,
+      () -> failureCallback0.onComplete(failureCallback0.getID()));
+    BlockPushReturnCode errorCode0 =
+      (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(e0.getResponse());
+    assertEquals(BlockPushNonFatalFailure.ReturnCode.TOO_LATE_BLOCK_PUSH.id(),
+      errorCode0.returnCode);
+    assertEquals(errorCode0.failureBlockId, "shufflePush_1_0_0_200");
+    //shufflePush_1_0_1_100 is received by the server after finalization of shuffle 1 0 which
+    //should also be rejected
+    StreamCallbackWithID failureCallback = pushResolver.receiveBlockDataAsStream(
+      new PushBlockStream(TEST_APP, NO_ATTEMPT_ID, 1, 0, 1, 100, 0));
+    BlockPushNonFatalFailure e1 = assertThrows(BlockPushNonFatalFailure.class,
+      () -> failureCallback.onComplete(failureCallback.getID()));
+    BlockPushReturnCode errorCode1 =
+      (BlockPushReturnCode) BlockTransferMessage.Decoder.fromByteBuffer(e1.getResponse());
+    assertEquals(BlockPushNonFatalFailure.ReturnCode.TOO_LATE_BLOCK_PUSH.id(),
+      errorCode1.returnCode);
+    assertEquals(errorCode1.failureBlockId, "shufflePush_1_0_1_100");
     MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 1, 0, 100);
     validateChunks(TEST_APP, 1, 0, 100, blockMeta, new int[]{4}, new int[][]{{0}});
     removeApplication(TEST_APP);
diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index e433dc0..f530c81 100644
--- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -613,12 +613,8 @@ public class UTF8StringSuite {
 
     for (final long offset : offsets) {
       try {
-        fromAddress(test, BYTE_ARRAY_OFFSET + offset, test.length)
-            .writeTo(outputStream);
-
-        throw new IllegalStateException(Long.toString(offset));
-      } catch (ArrayIndexOutOfBoundsException e) {
-        // ignore
+        assertThrows(ArrayIndexOutOfBoundsException.class,
+          () -> fromAddress(test, BYTE_ARRAY_OFFSET + offset, test.length).writeTo(outputStream));
       } finally {
         outputStream.reset();
       }
diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java
index 72b1245..5c88fb6 100644
--- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java
+++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java
@@ -51,20 +51,11 @@ public class SparkLauncherSuite extends BaseSuite {
     SparkSubmitOptionParser opts = new SparkSubmitOptionParser();
 
     launcher.addSparkArg(opts.HELP);
-    try {
-      launcher.addSparkArg(opts.PROXY_USER);
-      fail("Expected IllegalArgumentException.");
-    } catch (IllegalArgumentException e) {
-      // Expected.
-    }
+    assertThrows(IllegalArgumentException.class, () -> launcher.addSparkArg(opts.PROXY_USER));
 
     launcher.addSparkArg(opts.PROXY_USER, "someUser");
-    try {
-      launcher.addSparkArg(opts.HELP, "someValue");
-      fail("Expected IllegalArgumentException.");
-    } catch (IllegalArgumentException e) {
-      // Expected.
-    }
+    assertThrows(IllegalArgumentException.class,
+      () -> launcher.addSparkArg(opts.HELP, "someValue"));
 
     launcher.addSparkArg("--future-argument");
     launcher.addSparkArg("--future-argument", "someValue");
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
index 92bc740..1fd5aab 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/PackedRecordPointerSuite.java
@@ -30,7 +30,7 @@ import static org.apache.spark.shuffle.sort.PackedRecordPointer.MAXIMUM_PAGE_SIZ
 import static org.apache.spark.shuffle.sort.PackedRecordPointer.MAXIMUM_PARTITION_ID;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotEquals;
-import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.assertThrows;
 
 public class PackedRecordPointerSuite {
 
@@ -86,15 +86,9 @@ public class PackedRecordPointerSuite {
   @Test
   public void partitionIdsGreaterThanMaximumPartitionIdWillOverflowOrTriggerError() {
     PackedRecordPointer packedPointer = new PackedRecordPointer();
-    boolean asserted = false;
-    try {
-      // Pointers greater than the maximum partition ID will overflow or trigger an assertion error
-      packedPointer.set(PackedRecordPointer.packPointer(0, MAXIMUM_PARTITION_ID + 1));
-    } catch (AssertionError e ) {
-      // pass
-      asserted = true;
-    }
-    assertTrue(asserted);
+    // Pointers greater than the maximum partition ID will overflow or trigger an assertion error
+    assertThrows(AssertionError.class,
+      () -> packedPointer.set(PackedRecordPointer.packPointer(0, MAXIMUM_PARTITION_ID + 1)));
     assertNotEquals(MAXIMUM_PARTITION_ID + 1, packedPointer.getPartitionId());
   }
 
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index a9c81c5..a20a2a0 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -618,33 +618,14 @@ public abstract class AbstractBytesToBytesMapSuite {
 
   @Test
   public void initialCapacityBoundsChecking() {
-    try {
-      new BytesToBytesMap(taskMemoryManager, 0, PAGE_SIZE_BYTES);
-      Assert.fail("Expected IllegalArgumentException to be thrown");
-    } catch (IllegalArgumentException e) {
-      // expected exception
-    }
-
-    try {
-      new BytesToBytesMap(
-        taskMemoryManager,
-        BytesToBytesMap.MAX_CAPACITY + 1,
-        PAGE_SIZE_BYTES);
-      Assert.fail("Expected IllegalArgumentException to be thrown");
-    } catch (IllegalArgumentException e) {
-      // expected exception
-    }
-
-    try {
-      new BytesToBytesMap(
-        taskMemoryManager,
-        1,
-        TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES + 1);
-      Assert.fail("Expected IllegalArgumentException to be thrown");
-    } catch (IllegalArgumentException e) {
-      // expected exception
-    }
-
+    assertThrows(IllegalArgumentException.class,
+      () -> new BytesToBytesMap(taskMemoryManager, 0, PAGE_SIZE_BYTES));
+    assertThrows(IllegalArgumentException.class,
+      () -> new BytesToBytesMap(taskMemoryManager,
+              BytesToBytesMap.MAX_CAPACITY + 1, PAGE_SIZE_BYTES));
+    assertThrows(IllegalArgumentException.class,
+      () -> new BytesToBytesMap(taskMemoryManager, 1,
+              TaskMemoryManager.MAXIMUM_PAGE_SIZE_BYTES + 1));
   }
 
   @Test
@@ -742,10 +723,7 @@ public abstract class AbstractBytesToBytesMapSuite {
     // Force OOM on next memory allocation.
     memoryManager.markExecutionAsOutOfMemoryOnce();
     try {
-      map.reset();
-      Assert.fail("Expected SparkOutOfMemoryError to be thrown");
-    } catch (SparkOutOfMemoryError e) {
-      // Expected exception; do nothing.
+      assertThrows(SparkOutOfMemoryError.class, map::reset);
     } finally {
       map.free();
     }
diff --git a/core/src/test/java/test/org/apache/spark/JavaAPISuite.java b/core/src/test/java/test/org/apache/spark/JavaAPISuite.java
index fd91237..cba43d9 100644
--- a/core/src/test/java/test/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/test/org/apache/spark/JavaAPISuite.java
@@ -1494,12 +1494,7 @@ public class JavaAPISuite implements Serializable {
     future.cancel(true);
     assertTrue(future.isCancelled());
     assertTrue(future.isDone());
-    try {
-      future.get(2000, TimeUnit.MILLISECONDS);
-      fail("Expected future.get() for cancelled job to throw CancellationException");
-    } catch (CancellationException ignored) {
-      // pass
-    }
+    assertThrows(CancellationException.class, () -> future.get(2000, TimeUnit.MILLISECONDS));
   }
 
   @Test
@@ -1507,12 +1502,9 @@ public class JavaAPISuite implements Serializable {
     List<Integer> data = Arrays.asList(1, 2, 3, 4, 5);
     JavaRDD<Integer> rdd = sc.parallelize(data, 1);
     JavaFutureAction<Long> future = rdd.map(new BuggyMapFunction<>()).countAsync();
-    try {
-      future.get(2, TimeUnit.SECONDS);
-      fail("Expected future.get() for failed job to throw ExecutionException");
-    } catch (ExecutionException ee) {
-      assertTrue(Throwables.getStackTraceAsString(ee).contains("Custom exception!"));
-    }
+    ExecutionException ee = assertThrows(ExecutionException.class,
+      () -> future.get(2, TimeUnit.SECONDS));
+    assertTrue(Throwables.getStackTraceAsString(ee).contains("Custom exception!"));
     assertTrue(future.isDone());
   }
 
diff --git a/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java b/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java
index 22d9324..46cdffc 100644
--- a/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java
+++ b/launcher/src/test/java/org/apache/spark/launcher/CommandBuilderUtilsSuite.java
@@ -105,12 +105,7 @@ public class CommandBuilderUtilsSuite {
   }
 
   private static void testInvalidOpt(String opts) {
-    try {
-      parseOptionString(opts);
-      fail("Expected exception for invalid option string.");
-    } catch (IllegalArgumentException e) {
-      // pass.
-    }
+    assertThrows(IllegalArgumentException.class, () -> parseOptionString(opts));
   }
 
 }
diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java
index f8dc0ec..bf89de9 100644
--- a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java
+++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java
@@ -230,16 +230,14 @@ public class LauncherServerSuite extends BaseSuite {
   private void waitForError(TestClient client, String secret) throws Exception {
     final AtomicBoolean helloSent = new AtomicBoolean();
     eventually(Duration.ofSeconds(1), Duration.ofMillis(10), () -> {
-      try {
-        if (!helloSent.get()) {
-          client.send(new Hello(secret, "1.4.0"));
-          helloSent.set(true);
+      if (!helloSent.get()) {
+        if (client.isOpen()) {
+          assertThrows(IOException.class, () -> client.send(new SetAppId("appId")));
         } else {
-          client.send(new SetAppId("appId"));
+          assertThrows(IllegalStateException.class,
+            () -> client.send(new Hello(secret, "1.4.0")));
+          helloSent.set(true);
         }
-        fail("Expected error but message went through.");
-      } catch (IllegalStateException | IOException e) {
-        // Expected.
       }
     });
   }
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
index dd98513..5308d61 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
@@ -73,12 +73,8 @@ public class JavaRandomForestClassifierSuite extends SharedSparkSession {
     }
     String[] invalidStrategies = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"};
     for (String strategy : invalidStrategies) {
-      try {
-        rf.setFeatureSubsetStrategy(strategy);
-        Assert.fail("Expected exception to be thrown for invalid strategies");
-      } catch (Exception e) {
-        Assert.assertTrue(e instanceof IllegalArgumentException);
-      }
+      Assert.assertThrows(IllegalArgumentException.class,
+        () -> rf.setFeatureSubsetStrategy(strategy));
     }
 
     RandomForestClassificationModel model = rf.fit(dataFrame);
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
index 4ba13e2..d08040d 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
@@ -75,12 +75,8 @@ public class JavaRandomForestRegressorSuite extends SharedSparkSession {
     }
     String[] invalidStrategies = {"-.1", "-.10", "-0.10", ".0", "0.0", "1.1", "0"};
     for (String strategy : invalidStrategies) {
-      try {
-        rf.setFeatureSubsetStrategy(strategy);
-        Assert.fail("Expected exception to be thrown for invalid strategies");
-      } catch (Exception e) {
-        Assert.assertTrue(e instanceof IllegalArgumentException);
-      }
+      Assert.assertThrows(IllegalArgumentException.class,
+        () -> rf.setFeatureSubsetStrategy(strategy));
     }
 
     RandomForestRegressionModel model = rf.fit(dataFrame);
diff --git a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java
index e4f678f..e4287c4 100644
--- a/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/util/JavaDefaultReadWriteSuite.java
@@ -49,13 +49,7 @@ public class JavaDefaultReadWriteSuite extends SharedSparkSession {
     instance.set(instance.intParam(), 2);
     String outputPath = new File(tempDir, uid).getPath();
     instance.save(outputPath);
-    try {
-      instance.save(outputPath);
-      Assert.fail(
-        "Write without overwrite enabled should fail if the output directory already exists.");
-    } catch (IOException e) {
-      // expected
-    }
+    Assert.assertThrows(IOException.class, () -> instance.save(outputPath));
     instance.write().session(spark).overwrite().save(outputPath);
     MyParams newInstance = MyParams.load(outputPath);
     Assert.assertEquals("UID should match.", instance.uid(), newInstance.uid());
diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java
index d460a06..c7fdcc6 100644
--- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java
+++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalyst/expressions/RowBasedKeyValueBatchSuite.java
@@ -127,43 +127,10 @@ public class RowBasedKeyValueBatchSuite {
     try (RowBasedKeyValueBatch batch = RowBasedKeyValueBatch.allocate(keySchema,
         valueSchema, taskMemoryManager, DEFAULT_CAPACITY)) {
       Assert.assertEquals(0, batch.numRows());
-
-      boolean asserted = false;
-      try {
-        batch.getKeyRow(-1);
-      } catch (AssertionError e) {
-        // Expected exception; do nothing.
-        asserted = true;
-      }
-      Assert.assertTrue("Should not be able to get row -1", asserted);
-
-      asserted = false;
-      try {
-        batch.getValueRow(-1);
-      } catch (AssertionError e) {
-        // Expected exception; do nothing.
-        asserted = true;
-      }
-      Assert.assertTrue("Should not be able to get row -1", asserted);
-
-      asserted = false;
-      try {
-        batch.getKeyRow(0);
-      } catch (AssertionError e) {
-        // Expected exception; do nothing.
-        asserted = true;
-      }
-      Assert.assertTrue("Should not be able to get row 0 when batch is empty", asserted);
-
-      asserted = false;
-      try {
-        batch.getValueRow(0);
-      } catch (AssertionError e) {
-        // Expected exception; do nothing.
-        asserted = true;
-      }
-      Assert.assertTrue("Should not be able to get row 0 when batch is empty", asserted);
-
+      Assert.assertThrows(AssertionError.class, () -> batch.getKeyRow(-1));
+      Assert.assertThrows(AssertionError.class, () -> batch.getValueRow(-1));
+      Assert.assertThrows(AssertionError.class, () -> batch.getKeyRow(0));
+      Assert.assertThrows(AssertionError.class, () -> batch.getValueRow(0));
       Assert.assertFalse(batch.rowIterator().next());
     }
   }
@@ -199,23 +166,8 @@ public class RowBasedKeyValueBatchSuite {
       UnsafeRow retrievedValue2 = batch.getValueRow(2);
       Assert.assertTrue(checkValue(retrievedValue2, 3, 3));
 
-      boolean asserted = false;
-      try {
-        batch.getKeyRow(3);
-      } catch (AssertionError e) {
-        // Expected exception; do nothing.
-        asserted = true;
-      }
-      Assert.assertTrue("Should not be able to get row 3", asserted);
-
-      asserted = false;
-      try {
-        batch.getValueRow(3);
-      } catch (AssertionError e) {
-        // Expected exception; do nothing.
-        asserted = true;
-      }
-      Assert.assertTrue("Should not be able to get row 3", asserted);
+      Assert.assertThrows(AssertionError.class, () -> batch.getKeyRow(3));
+      Assert.assertThrows(AssertionError.class, () -> batch.getValueRow(3));
     }
   }
 
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
index af0a22b..3d78e06 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
@@ -183,18 +183,9 @@ public class JavaBeanDeserializationSuite implements Serializable {
 
     Dataset<Row> dataFrame = spark.createDataFrame(inputRows, schema);
 
-    try {
-      dataFrame.as(encoder).collect();
-      Assert.fail("Expected AnalysisException, but passed.");
-    } catch (Throwable e) {
-      // Here we need to handle weird case: compiler complains AnalysisException never be thrown
-      // in try statement, but it can be thrown actually. Maybe Scala-Java interop issue?
-      if (e instanceof AnalysisException) {
-        Assert.assertTrue(e.getMessage().contains("Cannot up cast "));
-      } else {
-        throw e;
-      }
-    }
+    AnalysisException e = Assert.assertThrows(AnalysisException.class,
+      () -> dataFrame.as(encoder).collect());
+    Assert.assertTrue(e.getMessage().contains("Cannot up cast "));
   }
 
   private static Row createRecordSpark22000Row(Long index) {
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaColumnExpressionSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaColumnExpressionSuite.java
index 4478742..7f9fdbd 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaColumnExpressionSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaColumnExpressionSuite.java
@@ -79,14 +79,12 @@ public class JavaColumnExpressionSuite {
       createStructField("a", IntegerType, false),
       createStructField("b", createArrayType(IntegerType, false), false)));
     Dataset<Row> df = spark.createDataFrame(rows, schema);
-    try {
-      df.filter(df.col("a").isInCollection(Arrays.asList(new Column("b"))));
-      Assert.fail("Expected org.apache.spark.sql.AnalysisException");
-    } catch (Exception e) {
-      Arrays.asList("cannot resolve",
-        "due to data type mismatch: Arguments must be same type but were")
-        .forEach(s -> Assert.assertTrue(
-          e.getMessage().toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))));
-    }
+    Exception e = Assert.assertThrows(Exception.class,
+      () -> df.filter(df.col("a").isInCollection(Arrays.asList(new Column("b")))));
+    Arrays.asList("cannot resolve",
+      "due to data type mismatch: Arguments must be same type but were")
+        .forEach(s ->
+          Assert.assertTrue(e.getMessage().toLowerCase(Locale.ROOT)
+            .contains(s.toLowerCase(Locale.ROOT))));
   }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org