You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by sr...@apache.org on 2022/01/15 14:55:27 UTC

[spark] branch master updated: [SPARK-37854][CORE] Replace type check with pattern matching in Spark code

This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c7c51bc  [SPARK-37854][CORE] Replace type check with pattern matching in Spark code
c7c51bc is described below

commit c7c51bcab5cb067d36bccf789e0e4ad7f37ffb7c
Author: yangjie01 <ya...@baidu.com>
AuthorDate: Sat Jan 15 08:54:16 2022 -0600

    [SPARK-37854][CORE] Replace type check with pattern matching in Spark code
    
    ### What changes were proposed in this pull request?
    
    Many methods in the Spark code base currently use `isInstanceOf` + `asInstanceOf` for type conversion. The main change of this PR is to replace these type checks with pattern matching for code simplification.
    
    ### Why are the changes needed?
    Code simplification
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Pass GA
    
    Closes #35154 from LuciferYang/SPARK-37854.
    
    Authored-by: yangjie01 <ya...@baidu.com>
    Signed-off-by: Sean Owen <sr...@gmail.com>
---
 .../main/scala/org/apache/spark/TestUtils.scala    | 36 ++++++------
 .../main/scala/org/apache/spark/api/r/SerDe.scala  | 12 ++--
 .../spark/internal/config/ConfigBuilder.scala      | 18 +++---
 .../scala/org/apache/spark/rdd/HadoopRDD.scala     | 64 +++++++++++-----------
 .../main/scala/org/apache/spark/rdd/PipedRDD.scala |  7 ++-
 core/src/main/scala/org/apache/spark/rdd/RDD.scala |  8 ++-
 .../main/scala/org/apache/spark/util/Utils.scala   | 38 ++++++-------
 .../storage/ShuffleBlockFetcherIteratorSuite.scala | 10 ++--
 .../org/apache/spark/util/FileAppenderSuite.scala  | 17 +++---
 .../scala/org/apache/spark/util/UtilsSuite.scala   | 19 ++++---
 .../apache/spark/examples/mllib/LDAExample.scala   | 11 ++--
 .../spark/mllib/api/python/PythonMLLibAPI.scala    | 12 ++--
 .../expressions/aggregate/Percentile.scala         | 14 ++---
 .../apache/spark/sql/catalyst/trees/TreeNode.scala |  7 +--
 .../sql/catalyst/encoders/RowEncoderSuite.scala    | 11 ++--
 .../sql/execution/columnar/ColumnAccessor.scala    | 10 ++--
 .../spark/sql/execution/columnar/ColumnType.scala  | 50 +++++++++--------
 .../sql/execution/datasources/FileScanRDD.scala    | 19 ++++---
 .../org/apache/spark/sql/jdbc/H2Dialect.scala      | 30 +++++-----
 .../spark/sql/SparkSessionExtensionSuite.scala     | 57 +++++++++----------
 .../sql/execution/joins/BroadcastJoinSuite.scala   | 13 ++---
 .../apache/spark/sql/streaming/StreamTest.scala    |  6 +-
 .../sql/hive/client/IsolatedClientLoader.scala     | 12 ++--
 .../spark/streaming/scheduler/JobGenerator.scala   | 10 ++--
 .../org/apache/spark/streaming/util/StateMap.scala | 21 +++----
 25 files changed, 263 insertions(+), 249 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala
index 20159af..d2af955 100644
--- a/core/src/main/scala/org/apache/spark/TestUtils.scala
+++ b/core/src/main/scala/org/apache/spark/TestUtils.scala
@@ -337,22 +337,26 @@ private[spark] object TestUtils {
     connection.setRequestMethod(method)
     headers.foreach { case (k, v) => connection.setRequestProperty(k, v) }
 
-    // Disable cert and host name validation for HTTPS tests.
-    if (connection.isInstanceOf[HttpsURLConnection]) {
-      val sslCtx = SSLContext.getInstance("SSL")
-      val trustManager = new X509TrustManager {
-        override def getAcceptedIssuers(): Array[X509Certificate] = null
-        override def checkClientTrusted(x509Certificates: Array[X509Certificate],
-            s: String): Unit = {}
-        override def checkServerTrusted(x509Certificates: Array[X509Certificate],
-            s: String): Unit = {}
-      }
-      val verifier = new HostnameVerifier() {
-        override def verify(hostname: String, session: SSLSession): Boolean = true
-      }
-      sslCtx.init(null, Array(trustManager), new SecureRandom())
-      connection.asInstanceOf[HttpsURLConnection].setSSLSocketFactory(sslCtx.getSocketFactory())
-      connection.asInstanceOf[HttpsURLConnection].setHostnameVerifier(verifier)
+    connection match {
+      // Disable cert and host name validation for HTTPS tests.
+      case httpsConnection: HttpsURLConnection =>
+        val sslCtx = SSLContext.getInstance("SSL")
+        val trustManager = new X509TrustManager {
+          override def getAcceptedIssuers: Array[X509Certificate] = null
+
+          override def checkClientTrusted(x509Certificates: Array[X509Certificate],
+              s: String): Unit = {}
+
+          override def checkServerTrusted(x509Certificates: Array[X509Certificate],
+              s: String): Unit = {}
+        }
+        val verifier = new HostnameVerifier() {
+          override def verify(hostname: String, session: SSLSession): Boolean = true
+        }
+        sslCtx.init(null, Array(trustManager), new SecureRandom())
+        httpsConnection.setSSLSocketFactory(sslCtx.getSocketFactory)
+        httpsConnection.setHostnameVerifier(verifier)
+      case _ => // do nothing
     }
 
     try {
diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
index 9172038..f9f8c56 100644
--- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -22,7 +22,7 @@ import java.nio.charset.StandardCharsets
 import java.sql.{Date, Time, Timestamp}
 
 import scala.collection.JavaConverters._
-import scala.collection.mutable.WrappedArray
+import scala.collection.mutable
 
 /**
  * Utility functions to serialize, deserialize objects to / from R
@@ -303,12 +303,10 @@ private[spark] object SerDe {
       // Convert ArrayType collected from DataFrame to Java array
       // Collected data of ArrayType from a DataFrame is observed to be of
       // type "scala.collection.mutable.WrappedArray"
-      val value =
-        if (obj.isInstanceOf[WrappedArray[_]]) {
-          obj.asInstanceOf[WrappedArray[_]].toArray
-        } else {
-          obj
-        }
+      val value = obj match {
+        case wa: mutable.WrappedArray[_] => wa.array
+        case other => other
+      }
 
       value match {
         case v: java.lang.Character =>
diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala
index 38e057b..e319026 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala
@@ -140,15 +140,15 @@ private[spark] class TypedConfigBuilder[T](
   def createWithDefault(default: T): ConfigEntry[T] = {
     // Treat "String" as a special case, so that both createWithDefault and createWithDefaultString
     // behave the same w.r.t. variable expansion of default values.
-    if (default.isInstanceOf[String]) {
-      createWithDefaultString(default.asInstanceOf[String])
-    } else {
-      val transformedDefault = converter(stringConverter(default))
-      val entry = new ConfigEntryWithDefault[T](parent.key, parent._prependedKey,
-        parent._prependSeparator, parent._alternatives, transformedDefault, converter,
-        stringConverter, parent._doc, parent._public, parent._version)
-      parent._onCreate.foreach(_(entry))
-      entry
+    default match {
+      case str: String => createWithDefaultString(str)
+      case _ =>
+        val transformedDefault = converter(stringConverter(default))
+        val entry = new ConfigEntryWithDefault[T](parent.key, parent._prependedKey,
+          parent._prependSeparator, parent._alternatives, transformedDefault, converter,
+          stringConverter, parent._doc, parent._public, parent._version)
+        parent._onCreate.foreach(_(entry))
+        entry
     }
   }
 
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 7011451..fcc2275 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -61,14 +61,14 @@ private[spark] class HadoopPartition(rddId: Int, override val index: Int, s: Inp
    * @return a Map with the environment variables and corresponding values, it could be empty
    */
   def getPipeEnvVars(): Map[String, String] = {
-    val envVars: Map[String, String] = if (inputSplit.value.isInstanceOf[FileSplit]) {
-      val is: FileSplit = inputSplit.value.asInstanceOf[FileSplit]
-      // map_input_file is deprecated in favor of mapreduce_map_input_file but set both
-      // since it's not removed yet
-      Map("map_input_file" -> is.getPath().toString(),
-        "mapreduce_map_input_file" -> is.getPath().toString())
-    } else {
-      Map()
+    val envVars: Map[String, String] = inputSplit.value match {
+      case is: FileSplit =>
+        // map_input_file is deprecated in favor of mapreduce_map_input_file but set both
+        // since it's not removed yet
+        Map("map_input_file" -> is.getPath().toString(),
+          "mapreduce_map_input_file" -> is.getPath().toString())
+      case _ =>
+        Map()
     }
     envVars
   }
@@ -161,29 +161,31 @@ class HadoopRDD[K, V](
         newJobConf
       }
     } else {
-      if (conf.isInstanceOf[JobConf]) {
-        logDebug("Re-using user-broadcasted JobConf")
-        conf.asInstanceOf[JobConf]
-      } else {
-        Option(HadoopRDD.getCachedMetadata(jobConfCacheKey))
-          .map { conf =>
-            logDebug("Re-using cached JobConf")
-            conf.asInstanceOf[JobConf]
-          }
-          .getOrElse {
-            // Create a JobConf that will be cached and used across this RDD's getJobConf() calls in
-            // the local process. The local cache is accessed through HadoopRDD.putCachedMetadata().
-            // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary
-            // objects. Synchronize to prevent ConcurrentModificationException (SPARK-1097,
-            // HADOOP-10456).
-            HadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized {
-              logDebug("Creating new JobConf and caching it for later re-use")
-              val newJobConf = new JobConf(conf)
-              initLocalJobConfFuncOpt.foreach(f => f(newJobConf))
-              HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf)
-              newJobConf
-          }
-        }
+      conf match {
+        case jobConf: JobConf =>
+          logDebug("Re-using user-broadcasted JobConf")
+          jobConf
+        case _ =>
+          Option(HadoopRDD.getCachedMetadata(jobConfCacheKey))
+            .map { conf =>
+              logDebug("Re-using cached JobConf")
+              conf.asInstanceOf[JobConf]
+            }
+            .getOrElse {
+              // Create a JobConf that will be cached and used across this RDD's getJobConf()
+              // calls in the local process. The local cache is accessed through
+              // HadoopRDD.putCachedMetadata().
+              // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary
+              // objects. Synchronize to prevent ConcurrentModificationException (SPARK-1097,
+              // HADOOP-10456).
+              HadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized {
+                logDebug("Creating new JobConf and caching it for later re-use")
+                val newJobConf = new JobConf(conf)
+                initLocalJobConfFuncOpt.foreach(f => f(newJobConf))
+                HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf)
+                newJobConf
+              }
+            }
       }
     }
   }
diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
index 285da04..7e121e9 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
@@ -72,9 +72,10 @@ private[spark] class PipedRDD[T: ClassTag](
 
     // for compatibility with Hadoop which sets these env variables
     // so the user code can access the input filename
-    if (split.isInstanceOf[HadoopPartition]) {
-      val hadoopSplit = split.asInstanceOf[HadoopPartition]
-      currentEnvVars.putAll(hadoopSplit.getPipeEnvVars().asJava)
+    split match {
+      case hadoopSplit: HadoopPartition =>
+        currentEnvVars.putAll(hadoopSplit.getPipeEnvVars().asJava)
+      case _ => // do nothing
     }
 
     // When spark.worker.separated.working.directory option is turned on, each
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 4c39d17..7188566 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1764,9 +1764,11 @@ abstract class RDD[T: ClassTag](
        * Clean the shuffles & all of its parents.
        */
       def cleanEagerly(dep: Dependency[_]): Unit = {
-        if (dep.isInstanceOf[ShuffleDependency[_, _, _]]) {
-          val shuffleId = dep.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
-          cleaner.doCleanupShuffle(shuffleId, blocking)
+        dep match {
+          case dependency: ShuffleDependency[_, _, _] =>
+            val shuffleId = dependency.shuffleId
+            cleaner.doCleanupShuffle(shuffleId, blocking)
+          case _ => // do nothing
         }
         val rdd = dep.rdd
         val rddDepsOpt = rdd.internalDependencies
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 8f3d1de..a9d6180 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -355,26 +355,26 @@ private[spark] object Utils extends Logging {
       closeStreams: Boolean = false,
       transferToEnabled: Boolean = false): Long = {
     tryWithSafeFinally {
-      if (in.isInstanceOf[FileInputStream] && out.isInstanceOf[FileOutputStream]
-        && transferToEnabled) {
-        // When both streams are File stream, use transferTo to improve copy performance.
-        val inChannel = in.asInstanceOf[FileInputStream].getChannel()
-        val outChannel = out.asInstanceOf[FileOutputStream].getChannel()
-        val size = inChannel.size()
-        copyFileStreamNIO(inChannel, outChannel, 0, size)
-        size
-      } else {
-        var count = 0L
-        val buf = new Array[Byte](8192)
-        var n = 0
-        while (n != -1) {
-          n = in.read(buf)
-          if (n != -1) {
-            out.write(buf, 0, n)
-            count += n
+      (in, out) match {
+        case (input: FileInputStream, output: FileOutputStream) if transferToEnabled =>
+          // When both streams are File stream, use transferTo to improve copy performance.
+          val inChannel = input.getChannel
+          val outChannel = output.getChannel
+          val size = inChannel.size()
+          copyFileStreamNIO(inChannel, outChannel, 0, size)
+          size
+        case (input, output) =>
+          var count = 0L
+          val buf = new Array[Byte](8192)
+          var n = 0
+          while (n != -1) {
+            n = input.read(buf)
+            if (n != -1) {
+              output.write(buf, 0, n)
+              count += n
+            }
           }
-        }
-        count
+          count
       }
     } {
       if (closeStreams) {
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index afb9a86..56043ea 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -160,10 +160,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     verify(buffer, times(0)).release()
     val delegateAccess = PrivateMethod[InputStream](Symbol("delegate"))
     var in = wrappedInputStream.invokePrivate(delegateAccess())
-    if (in.isInstanceOf[CheckedInputStream]) {
-      val underlyingInputFiled = classOf[CheckedInputStream].getSuperclass.getDeclaredField("in")
-      underlyingInputFiled.setAccessible(true)
-      in = underlyingInputFiled.get(in.asInstanceOf[CheckedInputStream]).asInstanceOf[InputStream]
+    in match {
+      case stream: CheckedInputStream =>
+        val underlyingInputFiled = classOf[CheckedInputStream].getSuperclass.getDeclaredField("in")
+        underlyingInputFiled.setAccessible(true)
+        in = underlyingInputFiled.get(stream).asInstanceOf[InputStream]
+      case _ => // do nothing
     }
     verify(in, times(0)).close()
     wrappedInputStream.close()
diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala
index 1a2eb69..8ca4bc9 100644
--- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala
@@ -222,14 +222,15 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging {
       // assert(appender.getClass === classTag[ExpectedAppender].getClass)
       assert(appender.getClass.getSimpleName ===
         classTag[ExpectedAppender].runtimeClass.getSimpleName)
-      if (appender.isInstanceOf[RollingFileAppender]) {
-        val rollingPolicy = appender.asInstanceOf[RollingFileAppender].rollingPolicy
-        val policyParam = if (rollingPolicy.isInstanceOf[TimeBasedRollingPolicy]) {
-          rollingPolicy.asInstanceOf[TimeBasedRollingPolicy].rolloverIntervalMillis
-        } else {
-          rollingPolicy.asInstanceOf[SizeBasedRollingPolicy].rolloverSizeBytes
-        }
-        assert(policyParam === expectedRollingPolicyParam)
+      appender match {
+        case rfa: RollingFileAppender =>
+          val rollingPolicy = rfa.rollingPolicy
+          val policyParam = rollingPolicy match {
+            case timeBased: TimeBasedRollingPolicy => timeBased.rolloverIntervalMillis
+            case sizeBased: SizeBasedRollingPolicy => sizeBased.rolloverSizeBytes
+          }
+          assert(policyParam === expectedRollingPolicyParam)
+        case _ => // do nothing
       }
       testOutputStream.close()
       appender.awaitTermination()
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index 6117dec..62cd819 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -227,15 +227,16 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
       try {
         // Get a handle on the buffered data, to make sure memory gets freed once we read past the
         // end of it. Need to use reflection to get handle on inner structures for this check
-        val byteBufferInputStream = if (mergedStream.isInstanceOf[ChunkedByteBufferInputStream]) {
-          assert(inputLength < limit)
-          mergedStream.asInstanceOf[ChunkedByteBufferInputStream]
-        } else {
-          assert(inputLength >= limit)
-          val sequenceStream = mergedStream.asInstanceOf[SequenceInputStream]
-          val fieldValue = getFieldValue(sequenceStream, "in")
-          assert(fieldValue.isInstanceOf[ChunkedByteBufferInputStream])
-          fieldValue.asInstanceOf[ChunkedByteBufferInputStream]
+        val byteBufferInputStream = mergedStream match {
+          case stream: ChunkedByteBufferInputStream =>
+            assert(inputLength < limit)
+            stream
+          case _ =>
+            assert(inputLength >= limit)
+            val sequenceStream = mergedStream.asInstanceOf[SequenceInputStream]
+            val fieldValue = getFieldValue(sequenceStream, "in")
+            assert(fieldValue.isInstanceOf[ChunkedByteBufferInputStream])
+            fieldValue.asInstanceOf[ChunkedByteBufferInputStream]
         }
         (0 until inputLength).foreach { idx =>
           assert(bytes(idx) === mergedStream.read().asInstanceOf[Byte])
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
index a3006a1..afd529c 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
@@ -158,11 +158,12 @@ object LDAExample {
     println(s"Finished training LDA model.  Summary:")
     println(s"\t Training time: $elapsed sec")
 
-    if (ldaModel.isInstanceOf[DistributedLDAModel]) {
-      val distLDAModel = ldaModel.asInstanceOf[DistributedLDAModel]
-      val avgLogLikelihood = distLDAModel.logLikelihood / actualCorpusSize.toDouble
-      println(s"\t Training data average log likelihood: $avgLogLikelihood")
-      println()
+    ldaModel match {
+      case distLDAModel: DistributedLDAModel =>
+        val avgLogLikelihood = distLDAModel.logLikelihood / actualCorpusSize.toDouble
+        println(s"\t Training data average log likelihood: $avgLogLikelihood")
+        println()
+      case _ => // do nothing
     }
 
     // Print the topics, showing the top-weighted terms for each topic.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 80707f0..56aaaa3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -90,12 +90,12 @@ private[python] class PythonMLLibAPI extends Serializable {
       initialWeights: Vector): JList[Object] = {
     try {
       val model = learner.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK), initialWeights)
-      if (model.isInstanceOf[LogisticRegressionModel]) {
-        val lrModel = model.asInstanceOf[LogisticRegressionModel]
-        List(lrModel.weights, lrModel.intercept, lrModel.numFeatures, lrModel.numClasses)
-          .map(_.asInstanceOf[Object]).asJava
-      } else {
-        List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava
+      model match {
+        case lrModel: LogisticRegressionModel =>
+          List(lrModel.weights, lrModel.intercept, lrModel.numFeatures, lrModel.numClasses)
+            .map(_.asInstanceOf[Object]).asJava
+        case _ =>
+          List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava
       }
     } finally {
       data.rdd.unpersist()
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
index 7d3dd0a..a98585e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
@@ -198,14 +198,12 @@ case class Percentile(
       return Seq.empty
     }
 
-    val ordering =
-      if (child.dataType.isInstanceOf[NumericType]) {
-        child.dataType.asInstanceOf[NumericType].ordering
-      } else if (child.dataType.isInstanceOf[YearMonthIntervalType]) {
-        child.dataType.asInstanceOf[YearMonthIntervalType].ordering
-      } else if (child.dataType.isInstanceOf[DayTimeIntervalType]) {
-        child.dataType.asInstanceOf[DayTimeIntervalType].ordering
-      }
+    val ordering = child.dataType match {
+      case numericType: NumericType => numericType.ordering
+      case intervalType: YearMonthIntervalType => intervalType.ordering
+      case intervalType: DayTimeIntervalType => intervalType.ordering
+      case otherType => throw QueryExecutionErrors.unsupportedTypeError(otherType)
+    }
     val sortedCounts = buffer.toSeq.sortBy(_._1)(ordering.asInstanceOf[Ordering[AnyRef]])
     val accumulatedCounts = sortedCounts.scanLeft((sortedCounts.head._1, 0L)) {
       case ((key1, count1), (key2, count2)) => (key2, count1 + count2)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index f78bbbf..9e50be3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -341,10 +341,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
   // This is a temporary solution, we will change the type of children to IndexedSeq in a
   // followup PR
   private def asIndexedSeq(seq: Seq[BaseType]): IndexedSeq[BaseType] = {
-    if (seq.isInstanceOf[IndexedSeq[BaseType]]) {
-      seq.asInstanceOf[IndexedSeq[BaseType]]
-    } else {
-      seq.toIndexedSeq
+    seq match {
+      case types: IndexedSeq[BaseType] => types
+      case other => other.toIndexedSeq
     }
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 1a42784..44b06d9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -31,12 +31,11 @@ import org.apache.spark.sql.types.DataTypeTestUtils.{dayTimeIntervalTypes, yearM
 class ExamplePoint(val x: Double, val y: Double) extends Serializable {
   override def hashCode: Int = 41 * (41 + x.toInt) + y.toInt
   override def equals(that: Any): Boolean = {
-    if (that.isInstanceOf[ExamplePoint]) {
-      val e = that.asInstanceOf[ExamplePoint]
-      (this.x == e.x || (this.x.isNaN && e.x.isNaN) || (this.x.isInfinity && e.x.isInfinity)) &&
-        (this.y == e.y || (this.y.isNaN && e.y.isNaN) || (this.y.isInfinity && e.y.isInfinity))
-    } else {
-      false
+    that match {
+      case e: ExamplePoint =>
+        (this.x == e.x || (this.x.isNaN && e.x.isNaN) || (this.x.isInfinity && e.x.isInfinity)) &&
+          (this.y == e.y || (this.y.isNaN && e.y.isNaN) || (this.y.isInfinity && e.y.isInfinity))
+      case _ => false
     }
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
index 2f68e89..fa7140b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
@@ -158,11 +158,11 @@ private[sql] object ColumnAccessor {
 
   def decompress(columnAccessor: ColumnAccessor, columnVector: WritableColumnVector, numRows: Int):
       Unit = {
-    if (columnAccessor.isInstanceOf[NativeColumnAccessor[_]]) {
-      val nativeAccessor = columnAccessor.asInstanceOf[NativeColumnAccessor[_]]
-      nativeAccessor.decompress(columnVector, numRows)
-    } else {
-      throw QueryExecutionErrors.notSupportNonPrimitiveTypeError()
+    columnAccessor match {
+      case nativeAccessor: NativeColumnAccessor[_] =>
+        nativeAccessor.decompress(columnVector, numRows)
+      case _ =>
+        throw QueryExecutionErrors.notSupportNonPrimitiveTypeError()
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala
index 419dcc6..9b4c136 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala
@@ -473,23 +473,25 @@ private[columnar] trait DirectCopyColumnType[JvmType] extends ColumnType[JvmType
 
   // copy the bytes from ByteBuffer to UnsafeRow
   override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = {
-    if (row.isInstanceOf[MutableUnsafeRow]) {
-      val numBytes = buffer.getInt
-      val cursor = buffer.position()
-      buffer.position(cursor + numBytes)
-      row.asInstanceOf[MutableUnsafeRow].writer.write(ordinal, buffer.array(),
-        buffer.arrayOffset() + cursor, numBytes)
-    } else {
-      setField(row, ordinal, extract(buffer))
+    row match {
+      case mutable: MutableUnsafeRow =>
+        val numBytes = buffer.getInt
+        val cursor = buffer.position()
+        buffer.position(cursor + numBytes)
+        mutable.writer.write(ordinal, buffer.array(),
+          buffer.arrayOffset() + cursor, numBytes)
+      case _ =>
+        setField(row, ordinal, extract(buffer))
     }
   }
 
   // copy the bytes from UnsafeRow to ByteBuffer
   override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = {
-    if (row.isInstanceOf[UnsafeRow]) {
-      row.asInstanceOf[UnsafeRow].writeFieldTo(ordinal, buffer)
-    } else {
-      super.append(row, ordinal, buffer)
+    row match {
+      case unsafe: UnsafeRow =>
+        unsafe.writeFieldTo(ordinal, buffer)
+      case _ =>
+        super.append(row, ordinal, buffer)
     }
   }
 }
@@ -514,10 +516,11 @@ private[columnar] object STRING
   }
 
   override def setField(row: InternalRow, ordinal: Int, value: UTF8String): Unit = {
-    if (row.isInstanceOf[MutableUnsafeRow]) {
-      row.asInstanceOf[MutableUnsafeRow].writer.write(ordinal, value)
-    } else {
-      row.update(ordinal, value.clone())
+    row match {
+      case mutable: MutableUnsafeRow =>
+        mutable.writer.write(ordinal, value)
+      case _ =>
+        row.update(ordinal, value.clone())
     }
   }
 
@@ -792,13 +795,14 @@ private[columnar] object CALENDAR_INTERVAL extends ColumnType[CalendarInterval]
 
   // copy the bytes from ByteBuffer to UnsafeRow
   override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = {
-    if (row.isInstanceOf[MutableUnsafeRow]) {
-      val cursor = buffer.position()
-      buffer.position(cursor + defaultSize)
-      row.asInstanceOf[MutableUnsafeRow].writer.write(ordinal, buffer.array(),
-        buffer.arrayOffset() + cursor, defaultSize)
-    } else {
-      setField(row, ordinal, extract(buffer))
+    row match {
+      case mutable: MutableUnsafeRow =>
+        val cursor = buffer.position()
+        buffer.position(cursor + defaultSize)
+        mutable.writer.write(ordinal, buffer.array(),
+          buffer.arrayOffset() + cursor, defaultSize)
+      case _ =>
+        setField(row, ordinal, extract(buffer))
     }
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
index 47f279b..5baa597 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
@@ -214,16 +214,17 @@ class FileScanRDD(
         val nextElement = currentIterator.next()
         // TODO: we should have a better separation of row based and batch based scan, so that we
         // don't need to run this `if` for every record.
-        if (nextElement.isInstanceOf[ColumnarBatch]) {
-          incTaskInputMetricsBytesRead()
-          inputMetrics.incRecordsRead(nextElement.asInstanceOf[ColumnarBatch].numRows())
-        } else {
-          // too costly to update every record
-          if (inputMetrics.recordsRead %
-              SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) {
+        nextElement match {
+          case batch: ColumnarBatch => // one batch counts as numRows() records
             incTaskInputMetricsBytesRead()
-          }
-          inputMetrics.incRecordsRead(1)
+            inputMetrics.incRecordsRead(batch.numRows())
+          case _ => // one row counts as one record
+            // too costly to update every record
+            if (inputMetrics.recordsRead %
+              SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) {
+              incTaskInputMetricsBytesRead()
+            }
+            inputMetrics.incRecordsRead(1)
         }
         addMetadataColumnsIfNeeded(nextElement)
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
index 1f422e5..7bd51f8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
@@ -65,20 +65,22 @@ private object H2Dialect extends JdbcDialect {
   }
 
   override def classifyException(message: String, e: Throwable): AnalysisException = {
-    if (e.isInstanceOf[SQLException]) {
-      // Error codes are from https://www.h2database.com/javadoc/org/h2/api/ErrorCode.html
-      e.asInstanceOf[SQLException].getErrorCode match {
-        // TABLE_OR_VIEW_ALREADY_EXISTS_1
-        case 42101 =>
-          throw new TableAlreadyExistsException(message, cause = Some(e))
-        // TABLE_OR_VIEW_NOT_FOUND_1
-        case 42102 =>
-          throw new NoSuchTableException(message, cause = Some(e))
-        // SCHEMA_NOT_FOUND_1
-        case 90079 =>
-          throw new NoSuchNamespaceException(message, cause = Some(e))
-        case _ =>
-      }
+    e match {
+      case exception: SQLException => // translate known H2 error codes
+        // Error codes are from https://www.h2database.com/javadoc/org/h2/api/ErrorCode.html
+        exception.getErrorCode match {
+          // TABLE_OR_VIEW_ALREADY_EXISTS_1
+          case 42101 =>
+            throw new TableAlreadyExistsException(message, cause = Some(e))
+          // TABLE_OR_VIEW_NOT_FOUND_1
+          case 42102 =>
+            throw NoSuchTableException(message, cause = Some(e))
+          // SCHEMA_NOT_FOUND_1
+          case 90079 =>
+            throw NoSuchNamespaceException(message, cause = Some(e))
+          case _ => // unrecognized error code: fall through to the default below
+        }
+      case _ => // not an SQLException: use the default classification below
     }
     super.classifyException(message, e)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index 4994968..3577812 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -725,37 +725,32 @@ class BrokenColumnarAdd(
       lhs = left.columnarEval(batch)
       rhs = right.columnarEval(batch)
 
-      if (lhs == null || rhs == null) {
-        ret = null
-      } else if (lhs.isInstanceOf[ColumnVector] && rhs.isInstanceOf[ColumnVector]) {
-        val l = lhs.asInstanceOf[ColumnVector]
-        val r = rhs.asInstanceOf[ColumnVector]
-        val result = new OnHeapColumnVector(batch.numRows(), dataType)
-        ret = result
-
-        for (i <- 0 until batch.numRows()) {
-          result.appendLong(l.getLong(i) + r.getLong(i) + 1) // BUG to show we replaced Add
-        }
-      } else if (rhs.isInstanceOf[ColumnVector]) {
-        val l = lhs.asInstanceOf[Long]
-        val r = rhs.asInstanceOf[ColumnVector]
-        val result = new OnHeapColumnVector(batch.numRows(), dataType)
-        ret = result
-
-        for (i <- 0 until batch.numRows()) {
-          result.appendLong(l + r.getLong(i) + 1) // BUG to show we replaced Add
-        }
-      } else if (lhs.isInstanceOf[ColumnVector]) {
-        val l = lhs.asInstanceOf[ColumnVector]
-        val r = rhs.asInstanceOf[Long]
-        val result = new OnHeapColumnVector(batch.numRows(), dataType)
-        ret = result
-
-        for (i <- 0 until batch.numRows()) {
-          result.appendLong(l.getLong(i) + r + 1) // BUG to show we replaced Add
-        }
-      } else {
-        ret = nullSafeEval(lhs, rhs)
+      (lhs, rhs) match {
+        case (null, _) | (_, null) => // null if either side is null, as before
+          ret = null
+        case (l: ColumnVector, r: ColumnVector) =>
+          val result = new OnHeapColumnVector(batch.numRows(), dataType)
+          ret = result
+
+          for (i <- 0 until batch.numRows()) {
+            result.appendLong(l.getLong(i) + r.getLong(i) + 1) // BUG to show we replaced Add
+          }
+        case (l: Long, r: ColumnVector) =>
+          val result = new OnHeapColumnVector(batch.numRows(), dataType)
+          ret = result
+
+          for (i <- 0 until batch.numRows()) {
+            result.appendLong(l + r.getLong(i) + 1) // BUG to show we replaced Add
+          }
+        case (l: ColumnVector, r: Long) =>
+          val result = new OnHeapColumnVector(batch.numRows(), dataType)
+          ret = result
+
+          for (i <- 0 until batch.numRows()) {
+            result.appendLong(l.getLong(i) + r + 1) // BUG to show we replaced Add
+          }
+        case (l, r) =>
+          ret = nullSafeEval(l, r)
       }
     } finally {
       if (lhs != null && lhs.isInstanceOf[ColumnVector]) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index a8b4856..f27a249 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -402,13 +402,12 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils
         assert(b.buildSide === buildSide)
       case w: WholeStageCodegenExec =>
         assert(w.children.head.getClass.getSimpleName === joinMethod)
-        if (w.children.head.isInstanceOf[BroadcastNestedLoopJoinExec]) {
-          assert(
-            w.children.head.asInstanceOf[BroadcastNestedLoopJoinExec].buildSide === buildSide)
-        } else if (w.children.head.isInstanceOf[BroadcastHashJoinExec]) {
-          assert(w.children.head.asInstanceOf[BroadcastHashJoinExec].buildSide === buildSide)
-        } else {
-          fail()
+        w.children.head match {
+          case bnlj: BroadcastNestedLoopJoinExec =>
+            assert(bnlj.buildSide === buildSide)
+          case bhj: BroadcastHashJoinExec =>
+            assert(bhj.buildSide === buildSide)
+          case _ => fail() // neither a nested-loop nor a hash broadcast join
         }
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index ff182b5..2bb43ec 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -528,8 +528,10 @@ trait StreamTest extends QueryTest with SharedSparkSession with TimeLimits with
           verify(triggerClock.isInstanceOf[SystemClock]
             || triggerClock.isInstanceOf[StreamManualClock],
             "Use either SystemClock or StreamManualClock to start the stream")
-          if (triggerClock.isInstanceOf[StreamManualClock]) {
-            manualClockExpectedTime = triggerClock.asInstanceOf[StreamManualClock].getTimeMillis()
+          triggerClock match {
+            case clock: StreamManualClock =>
+              manualClockExpectedTime = clock.getTimeMillis()
+            case _ => // SystemClock (enforced by verify above): nothing to record
           }
           val metadataRoot = Option(checkpointLocation).getOrElse(defaultCheckpointLocation)
 
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala
index 828f987..671b80f 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala
@@ -316,12 +316,12 @@ private[hive] class IsolatedClientLoader(
         .asInstanceOf[HiveClient]
     } catch {
       case e: InvocationTargetException =>
-        if (e.getCause().isInstanceOf[NoClassDefFoundError]) {
-          val cnf = e.getCause().asInstanceOf[NoClassDefFoundError]
-          throw QueryExecutionErrors.loadHiveClientCausesNoClassDefFoundError(
-            cnf, execJars, HiveUtils.HIVE_METASTORE_JARS.key, e)
-        } else {
-          throw e
+        e.getCause match {
+          case cnf: NoClassDefFoundError => // rethrow via QueryExecutionErrors
+            throw QueryExecutionErrors.loadHiveClientCausesNoClassDefFoundError(
+              cnf, execJars, HiveUtils.HIVE_METASTORE_JARS.key, e)
+          case _ => // any other cause: rethrow the original exception
+            throw e
         }
     } finally {
       Thread.currentThread.setContextClassLoader(origLoader)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
index 8008a5c..282946dd 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
@@ -204,10 +204,12 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
     // If manual clock is being used for testing, then
     // either set the manual clock to the last checkpointed time,
     // or if the property is defined set it to that time
-    if (clock.isInstanceOf[ManualClock]) {
-      val lastTime = ssc.initialCheckpoint.checkpointTime.milliseconds
-      val jumpTime = ssc.sc.conf.get(StreamingConf.MANUAL_CLOCK_JUMP)
-      clock.asInstanceOf[ManualClock].setTime(lastTime + jumpTime)
+    clock match {
+      case manualClock: ManualClock =>
+        val lastTime = ssc.initialCheckpoint.checkpointTime.milliseconds
+        val jumpTime = ssc.sc.conf.get(StreamingConf.MANUAL_CLOCK_JUMP)
+        manualClock.setTime(lastTime + jumpTime)
+      case _ => // not a ManualClock: leave the clock as is
     }
 
     val batchDuration = ssc.graph.batchDuration
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
index 4224cef..8069e79 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
@@ -296,16 +296,17 @@ private[streaming] class OpenHashMapBasedStateMap[K, S](
     var parentSessionLoopDone = false
     while(!parentSessionLoopDone) {
       val obj = inputStream.readObject()
-      if (obj.isInstanceOf[LimitMarker]) {
-        parentSessionLoopDone = true
-        val expectedCount = obj.asInstanceOf[LimitMarker].num
-        assert(expectedCount == newParentSessionStore.deltaMap.size)
-      } else {
-        val key = obj.asInstanceOf[K]
-        val state = inputStream.readObject().asInstanceOf[S]
-        val updateTime = inputStream.readLong()
-        newParentSessionStore.deltaMap.update(
-          key, StateInfo(state, updateTime, deleted = false))
+      obj match {
+        case marker: LimitMarker => // end of parent entries: check the count
+          parentSessionLoopDone = true
+          val expectedCount = marker.num
+          assert(expectedCount == newParentSessionStore.deltaMap.size)
+        case _ => // obj is a key; its state and update time follow in the stream
+          val key = obj.asInstanceOf[K]
+          val state = inputStream.readObject().asInstanceOf[S]
+          val updateTime = inputStream.readLong()
+          newParentSessionStore.deltaMap.update(
+            key, StateInfo(state, updateTime, deleted = false))
       }
     }
     parentStateMap = newParentSessionStore

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org