Posted to commits@spark.apache.org by sr...@apache.org on 2022/01/15 14:55:27 UTC
[spark] branch master updated: [SPARK-37854][CORE] Replace type check with pattern matching in Spark code
This is an automated email from the ASF dual-hosted git repository.
srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new c7c51bc [SPARK-37854][CORE] Replace type check with pattern matching in Spark code
c7c51bc is described below
commit c7c51bcab5cb067d36bccf789e0e4ad7f37ffb7c
Author: yangjie01 <ya...@baidu.com>
AuthorDate: Sat Jan 15 08:54:16 2022 -0600
[SPARK-37854][CORE] Replace type check with pattern matching in Spark code
### What changes were proposed in this pull request?
Many methods in the Spark code base currently use `isInstanceOf` + `asInstanceOf` for type checks and conversions. The main change of this PR is to replace these `type check` constructs with `pattern matching` to simplify the code.
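For illustration, a minimal before/after sketch of the rewrite applied throughout this PR (the `describe` example below is hypothetical, not taken from the Spark code):

```scala
// Before: an explicit type check followed by a cast; the check and the
// cast can silently drift apart as the code evolves.
def describe(obj: Any): String = {
  if (obj.isInstanceOf[String]) {
    val s = obj.asInstanceOf[String]
    s"a string of length ${s.length}"
  } else {
    "something else"
  }
}

// After: a pattern match binds the narrowed value in one step.
def describeMatched(obj: Any): String = obj match {
  case s: String => s"a string of length ${s.length}"
  case _ => "something else"
}
```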
### Why are the changes needed?
Code simplification
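Besides the simple single-value rewrite, a few call sites below (e.g. `Utils.copyStream`) also fold two separate `isInstanceOf` checks into one match on a tuple with a guard; a minimal sketch with hypothetical names:

```scala
// Matching on a pair replaces nested isInstanceOf checks on two values.
def add(a: Any, b: Any, fast: Boolean): String = (a, b) match {
  case (x: Int, y: Int) if fast => s"fast path: ${x + y}" // both Ints and the guard holds
  case (x, y) => s"generic path: $x + $y"                 // everything else
}
```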
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Passed existing GitHub Actions (GA) checks.
Closes #35154 from LuciferYang/SPARK-37854.
Authored-by: yangjie01 <ya...@baidu.com>
Signed-off-by: Sean Owen <sr...@gmail.com>
---
.../main/scala/org/apache/spark/TestUtils.scala | 36 ++++++------
.../main/scala/org/apache/spark/api/r/SerDe.scala | 12 ++--
.../spark/internal/config/ConfigBuilder.scala | 18 +++---
.../scala/org/apache/spark/rdd/HadoopRDD.scala | 64 +++++++++++-----------
.../main/scala/org/apache/spark/rdd/PipedRDD.scala | 7 ++-
core/src/main/scala/org/apache/spark/rdd/RDD.scala | 8 ++-
.../main/scala/org/apache/spark/util/Utils.scala | 38 ++++++-------
.../storage/ShuffleBlockFetcherIteratorSuite.scala | 10 ++--
.../org/apache/spark/util/FileAppenderSuite.scala | 17 +++---
.../scala/org/apache/spark/util/UtilsSuite.scala | 19 ++++---
.../apache/spark/examples/mllib/LDAExample.scala | 11 ++--
.../spark/mllib/api/python/PythonMLLibAPI.scala | 12 ++--
.../expressions/aggregate/Percentile.scala | 14 ++---
.../apache/spark/sql/catalyst/trees/TreeNode.scala | 7 +--
.../sql/catalyst/encoders/RowEncoderSuite.scala | 11 ++--
.../sql/execution/columnar/ColumnAccessor.scala | 10 ++--
.../spark/sql/execution/columnar/ColumnType.scala | 50 +++++++++--------
.../sql/execution/datasources/FileScanRDD.scala | 19 ++++---
.../org/apache/spark/sql/jdbc/H2Dialect.scala | 30 +++++-----
.../spark/sql/SparkSessionExtensionSuite.scala | 57 +++++++++----------
.../sql/execution/joins/BroadcastJoinSuite.scala | 13 ++---
.../apache/spark/sql/streaming/StreamTest.scala | 6 +-
.../sql/hive/client/IsolatedClientLoader.scala | 12 ++--
.../spark/streaming/scheduler/JobGenerator.scala | 10 ++--
.../org/apache/spark/streaming/util/StateMap.scala | 21 +++----
25 files changed, 263 insertions(+), 249 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala
index 20159af..d2af955 100644
--- a/core/src/main/scala/org/apache/spark/TestUtils.scala
+++ b/core/src/main/scala/org/apache/spark/TestUtils.scala
@@ -337,22 +337,26 @@ private[spark] object TestUtils {
connection.setRequestMethod(method)
headers.foreach { case (k, v) => connection.setRequestProperty(k, v) }
- // Disable cert and host name validation for HTTPS tests.
- if (connection.isInstanceOf[HttpsURLConnection]) {
- val sslCtx = SSLContext.getInstance("SSL")
- val trustManager = new X509TrustManager {
- override def getAcceptedIssuers(): Array[X509Certificate] = null
- override def checkClientTrusted(x509Certificates: Array[X509Certificate],
- s: String): Unit = {}
- override def checkServerTrusted(x509Certificates: Array[X509Certificate],
- s: String): Unit = {}
- }
- val verifier = new HostnameVerifier() {
- override def verify(hostname: String, session: SSLSession): Boolean = true
- }
- sslCtx.init(null, Array(trustManager), new SecureRandom())
- connection.asInstanceOf[HttpsURLConnection].setSSLSocketFactory(sslCtx.getSocketFactory())
- connection.asInstanceOf[HttpsURLConnection].setHostnameVerifier(verifier)
+ connection match {
+ // Disable cert and host name validation for HTTPS tests.
+ case httpConnection: HttpsURLConnection =>
+ val sslCtx = SSLContext.getInstance("SSL")
+ val trustManager = new X509TrustManager {
+ override def getAcceptedIssuers: Array[X509Certificate] = null
+
+ override def checkClientTrusted(x509Certificates: Array[X509Certificate],
+ s: String): Unit = {}
+
+ override def checkServerTrusted(x509Certificates: Array[X509Certificate],
+ s: String): Unit = {}
+ }
+ val verifier = new HostnameVerifier() {
+ override def verify(hostname: String, session: SSLSession): Boolean = true
+ }
+ sslCtx.init(null, Array(trustManager), new SecureRandom())
+ httpConnection.setSSLSocketFactory(sslCtx.getSocketFactory)
+ httpConnection.setHostnameVerifier(verifier)
+ case _ => // do nothing
}
try {
diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
index 9172038..f9f8c56 100644
--- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -22,7 +22,7 @@ import java.nio.charset.StandardCharsets
import java.sql.{Date, Time, Timestamp}
import scala.collection.JavaConverters._
-import scala.collection.mutable.WrappedArray
+import scala.collection.mutable
/**
* Utility functions to serialize, deserialize objects to / from R
@@ -303,12 +303,10 @@ private[spark] object SerDe {
// Convert ArrayType collected from DataFrame to Java array
// Collected data of ArrayType from a DataFrame is observed to be of
// type "scala.collection.mutable.WrappedArray"
- val value =
- if (obj.isInstanceOf[WrappedArray[_]]) {
- obj.asInstanceOf[WrappedArray[_]].toArray
- } else {
- obj
- }
+ val value = obj match {
+ case wa: mutable.WrappedArray[_] => wa.array
+ case other => other
+ }
value match {
case v: java.lang.Character =>
diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala
index 38e057b..e319026 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala
@@ -140,15 +140,15 @@ private[spark] class TypedConfigBuilder[T](
def createWithDefault(default: T): ConfigEntry[T] = {
// Treat "String" as a special case, so that both createWithDefault and createWithDefaultString
// behave the same w.r.t. variable expansion of default values.
- if (default.isInstanceOf[String]) {
- createWithDefaultString(default.asInstanceOf[String])
- } else {
- val transformedDefault = converter(stringConverter(default))
- val entry = new ConfigEntryWithDefault[T](parent.key, parent._prependedKey,
- parent._prependSeparator, parent._alternatives, transformedDefault, converter,
- stringConverter, parent._doc, parent._public, parent._version)
- parent._onCreate.foreach(_(entry))
- entry
+ default match {
+ case str: String => createWithDefaultString(str)
+ case _ =>
+ val transformedDefault = converter(stringConverter(default))
+ val entry = new ConfigEntryWithDefault[T](parent.key, parent._prependedKey,
+ parent._prependSeparator, parent._alternatives, transformedDefault, converter,
+ stringConverter, parent._doc, parent._public, parent._version)
+ parent._onCreate.foreach(_ (entry))
+ entry
}
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 7011451..fcc2275 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -61,14 +61,14 @@ private[spark] class HadoopPartition(rddId: Int, override val index: Int, s: Inp
* @return a Map with the environment variables and corresponding values, it could be empty
*/
def getPipeEnvVars(): Map[String, String] = {
- val envVars: Map[String, String] = if (inputSplit.value.isInstanceOf[FileSplit]) {
- val is: FileSplit = inputSplit.value.asInstanceOf[FileSplit]
- // map_input_file is deprecated in favor of mapreduce_map_input_file but set both
- // since it's not removed yet
- Map("map_input_file" -> is.getPath().toString(),
- "mapreduce_map_input_file" -> is.getPath().toString())
- } else {
- Map()
+ val envVars: Map[String, String] = inputSplit.value match {
+ case is: FileSplit =>
+ // map_input_file is deprecated in favor of mapreduce_map_input_file but set both
+ // since it's not removed yet
+ Map("map_input_file" -> is.getPath().toString(),
+ "mapreduce_map_input_file" -> is.getPath().toString())
+ case _ =>
+ Map()
}
envVars
}
@@ -161,29 +161,31 @@ class HadoopRDD[K, V](
newJobConf
}
} else {
- if (conf.isInstanceOf[JobConf]) {
- logDebug("Re-using user-broadcasted JobConf")
- conf.asInstanceOf[JobConf]
- } else {
- Option(HadoopRDD.getCachedMetadata(jobConfCacheKey))
- .map { conf =>
- logDebug("Re-using cached JobConf")
- conf.asInstanceOf[JobConf]
- }
- .getOrElse {
- // Create a JobConf that will be cached and used across this RDD's getJobConf() calls in
- // the local process. The local cache is accessed through HadoopRDD.putCachedMetadata().
- // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary
- // objects. Synchronize to prevent ConcurrentModificationException (SPARK-1097,
- // HADOOP-10456).
- HadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized {
- logDebug("Creating new JobConf and caching it for later re-use")
- val newJobConf = new JobConf(conf)
- initLocalJobConfFuncOpt.foreach(f => f(newJobConf))
- HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf)
- newJobConf
- }
- }
+ conf match {
+ case jobConf: JobConf =>
+ logDebug("Re-using user-broadcasted JobConf")
+ jobConf
+ case _ =>
+ Option(HadoopRDD.getCachedMetadata(jobConfCacheKey))
+ .map { conf =>
+ logDebug("Re-using cached JobConf")
+ conf.asInstanceOf[JobConf]
+ }
+ .getOrElse {
+ // Create a JobConf that will be cached and used across this RDD's getJobConf()
+ // calls in the local process. The local cache is accessed through
+ // HadoopRDD.putCachedMetadata().
+ // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary
+ // objects. Synchronize to prevent ConcurrentModificationException (SPARK-1097,
+ // HADOOP-10456).
+ HadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized {
+ logDebug("Creating new JobConf and caching it for later re-use")
+ val newJobConf = new JobConf(conf)
+ initLocalJobConfFuncOpt.foreach(f => f(newJobConf))
+ HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf)
+ newJobConf
+ }
+ }
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
index 285da04..7e121e9 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
@@ -72,9 +72,10 @@ private[spark] class PipedRDD[T: ClassTag](
// for compatibility with Hadoop which sets these env variables
// so the user code can access the input filename
- if (split.isInstanceOf[HadoopPartition]) {
- val hadoopSplit = split.asInstanceOf[HadoopPartition]
- currentEnvVars.putAll(hadoopSplit.getPipeEnvVars().asJava)
+ split match {
+ case hadoopSplit: HadoopPartition =>
+ currentEnvVars.putAll(hadoopSplit.getPipeEnvVars().asJava)
+ case _ => // do nothing
}
// When spark.worker.separated.working.directory option is turned on, each
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 4c39d17..7188566 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1764,9 +1764,11 @@ abstract class RDD[T: ClassTag](
* Clean the shuffles & all of its parents.
*/
def cleanEagerly(dep: Dependency[_]): Unit = {
- if (dep.isInstanceOf[ShuffleDependency[_, _, _]]) {
- val shuffleId = dep.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
- cleaner.doCleanupShuffle(shuffleId, blocking)
+ dep match {
+ case dependency: ShuffleDependency[_, _, _] =>
+ val shuffleId = dependency.shuffleId
+ cleaner.doCleanupShuffle(shuffleId, blocking)
+ case _ => // do nothing
}
val rdd = dep.rdd
val rddDepsOpt = rdd.internalDependencies
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 8f3d1de..a9d6180 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -355,26 +355,26 @@ private[spark] object Utils extends Logging {
closeStreams: Boolean = false,
transferToEnabled: Boolean = false): Long = {
tryWithSafeFinally {
- if (in.isInstanceOf[FileInputStream] && out.isInstanceOf[FileOutputStream]
- && transferToEnabled) {
- // When both streams are File stream, use transferTo to improve copy performance.
- val inChannel = in.asInstanceOf[FileInputStream].getChannel()
- val outChannel = out.asInstanceOf[FileOutputStream].getChannel()
- val size = inChannel.size()
- copyFileStreamNIO(inChannel, outChannel, 0, size)
- size
- } else {
- var count = 0L
- val buf = new Array[Byte](8192)
- var n = 0
- while (n != -1) {
- n = in.read(buf)
- if (n != -1) {
- out.write(buf, 0, n)
- count += n
+ (in, out) match {
+ case (input: FileInputStream, output: FileOutputStream) if transferToEnabled =>
+ // When both streams are File stream, use transferTo to improve copy performance.
+ val inChannel = input.getChannel
+ val outChannel = output.getChannel
+ val size = inChannel.size()
+ copyFileStreamNIO(inChannel, outChannel, 0, size)
+ size
+ case (input, output) =>
+ var count = 0L
+ val buf = new Array[Byte](8192)
+ var n = 0
+ while (n != -1) {
+ n = input.read(buf)
+ if (n != -1) {
+ output.write(buf, 0, n)
+ count += n
+ }
}
- }
- count
+ count
}
} {
if (closeStreams) {
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index afb9a86..56043ea 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -160,10 +160,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
verify(buffer, times(0)).release()
val delegateAccess = PrivateMethod[InputStream](Symbol("delegate"))
var in = wrappedInputStream.invokePrivate(delegateAccess())
- if (in.isInstanceOf[CheckedInputStream]) {
- val underlyingInputFiled = classOf[CheckedInputStream].getSuperclass.getDeclaredField("in")
- underlyingInputFiled.setAccessible(true)
- in = underlyingInputFiled.get(in.asInstanceOf[CheckedInputStream]).asInstanceOf[InputStream]
+ in match {
+ case stream: CheckedInputStream =>
+ val underlyingInputFiled = classOf[CheckedInputStream].getSuperclass.getDeclaredField("in")
+ underlyingInputFiled.setAccessible(true)
+ in = underlyingInputFiled.get(stream).asInstanceOf[InputStream]
+ case _ => // do nothing
}
verify(in, times(0)).close()
wrappedInputStream.close()
diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala
index 1a2eb69..8ca4bc9 100644
--- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala
@@ -222,14 +222,15 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging {
// assert(appender.getClass === classTag[ExpectedAppender].getClass)
assert(appender.getClass.getSimpleName ===
classTag[ExpectedAppender].runtimeClass.getSimpleName)
- if (appender.isInstanceOf[RollingFileAppender]) {
- val rollingPolicy = appender.asInstanceOf[RollingFileAppender].rollingPolicy
- val policyParam = if (rollingPolicy.isInstanceOf[TimeBasedRollingPolicy]) {
- rollingPolicy.asInstanceOf[TimeBasedRollingPolicy].rolloverIntervalMillis
- } else {
- rollingPolicy.asInstanceOf[SizeBasedRollingPolicy].rolloverSizeBytes
- }
- assert(policyParam === expectedRollingPolicyParam)
+ appender match {
+ case rfa: RollingFileAppender =>
+ val rollingPolicy = rfa.rollingPolicy
+ val policyParam = rollingPolicy match {
+ case timeBased: TimeBasedRollingPolicy => timeBased.rolloverIntervalMillis
+ case sizeBased: SizeBasedRollingPolicy => sizeBased.rolloverSizeBytes
+ }
+ assert(policyParam === expectedRollingPolicyParam)
+ case _ => // do nothing
}
testOutputStream.close()
appender.awaitTermination()
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index 6117dec..62cd819 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -227,15 +227,16 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
try {
// Get a handle on the buffered data, to make sure memory gets freed once we read past the
// end of it. Need to use reflection to get handle on inner structures for this check
- val byteBufferInputStream = if (mergedStream.isInstanceOf[ChunkedByteBufferInputStream]) {
- assert(inputLength < limit)
- mergedStream.asInstanceOf[ChunkedByteBufferInputStream]
- } else {
- assert(inputLength >= limit)
- val sequenceStream = mergedStream.asInstanceOf[SequenceInputStream]
- val fieldValue = getFieldValue(sequenceStream, "in")
- assert(fieldValue.isInstanceOf[ChunkedByteBufferInputStream])
- fieldValue.asInstanceOf[ChunkedByteBufferInputStream]
+ val byteBufferInputStream = mergedStream match {
+ case stream: ChunkedByteBufferInputStream =>
+ assert(inputLength < limit)
+ stream
+ case _ =>
+ assert(inputLength >= limit)
+ val sequenceStream = mergedStream.asInstanceOf[SequenceInputStream]
+ val fieldValue = getFieldValue(sequenceStream, "in")
+ assert(fieldValue.isInstanceOf[ChunkedByteBufferInputStream])
+ fieldValue.asInstanceOf[ChunkedByteBufferInputStream]
}
(0 until inputLength).foreach { idx =>
assert(bytes(idx) === mergedStream.read().asInstanceOf[Byte])
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
index a3006a1..afd529c 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
@@ -158,11 +158,12 @@ object LDAExample {
println(s"Finished training LDA model. Summary:")
println(s"\t Training time: $elapsed sec")
- if (ldaModel.isInstanceOf[DistributedLDAModel]) {
- val distLDAModel = ldaModel.asInstanceOf[DistributedLDAModel]
- val avgLogLikelihood = distLDAModel.logLikelihood / actualCorpusSize.toDouble
- println(s"\t Training data average log likelihood: $avgLogLikelihood")
- println()
+ ldaModel match {
+ case distLDAModel: DistributedLDAModel =>
+ val avgLogLikelihood = distLDAModel.logLikelihood / actualCorpusSize.toDouble
+ println(s"\t Training data average log likelihood: $avgLogLikelihood")
+ println()
+ case _ => // do nothing
}
// Print the topics, showing the top-weighted terms for each topic.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 80707f0..56aaaa3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -90,12 +90,12 @@ private[python] class PythonMLLibAPI extends Serializable {
initialWeights: Vector): JList[Object] = {
try {
val model = learner.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK), initialWeights)
- if (model.isInstanceOf[LogisticRegressionModel]) {
- val lrModel = model.asInstanceOf[LogisticRegressionModel]
- List(lrModel.weights, lrModel.intercept, lrModel.numFeatures, lrModel.numClasses)
- .map(_.asInstanceOf[Object]).asJava
- } else {
- List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava
+ model match {
+ case lrModel: LogisticRegressionModel =>
+ List(lrModel.weights, lrModel.intercept, lrModel.numFeatures, lrModel.numClasses)
+ .map(_.asInstanceOf[Object]).asJava
+ case _ =>
+ List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava
}
} finally {
data.rdd.unpersist()
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
index 7d3dd0a..a98585e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
@@ -198,14 +198,12 @@ case class Percentile(
return Seq.empty
}
- val ordering =
- if (child.dataType.isInstanceOf[NumericType]) {
- child.dataType.asInstanceOf[NumericType].ordering
- } else if (child.dataType.isInstanceOf[YearMonthIntervalType]) {
- child.dataType.asInstanceOf[YearMonthIntervalType].ordering
- } else if (child.dataType.isInstanceOf[DayTimeIntervalType]) {
- child.dataType.asInstanceOf[DayTimeIntervalType].ordering
- }
+ val ordering = child.dataType match {
+ case numericType: NumericType => numericType.ordering
+ case intervalType: YearMonthIntervalType => intervalType.ordering
+ case intervalType: DayTimeIntervalType => intervalType.ordering
+ case otherType => QueryExecutionErrors.unsupportedTypeError(otherType)
+ }
val sortedCounts = buffer.toSeq.sortBy(_._1)(ordering.asInstanceOf[Ordering[AnyRef]])
val accumulatedCounts = sortedCounts.scanLeft((sortedCounts.head._1, 0L)) {
case ((key1, count1), (key2, count2)) => (key2, count1 + count2)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index f78bbbf..9e50be3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -341,10 +341,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
// This is a temporary solution, we will change the type of children to IndexedSeq in a
// followup PR
private def asIndexedSeq(seq: Seq[BaseType]): IndexedSeq[BaseType] = {
- if (seq.isInstanceOf[IndexedSeq[BaseType]]) {
- seq.asInstanceOf[IndexedSeq[BaseType]]
- } else {
- seq.toIndexedSeq
+ seq match {
+ case types: IndexedSeq[BaseType] => types
+ case other => other.toIndexedSeq
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 1a42784..44b06d9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -31,12 +31,11 @@ import org.apache.spark.sql.types.DataTypeTestUtils.{dayTimeIntervalTypes, yearM
class ExamplePoint(val x: Double, val y: Double) extends Serializable {
override def hashCode: Int = 41 * (41 + x.toInt) + y.toInt
override def equals(that: Any): Boolean = {
- if (that.isInstanceOf[ExamplePoint]) {
- val e = that.asInstanceOf[ExamplePoint]
- (this.x == e.x || (this.x.isNaN && e.x.isNaN) || (this.x.isInfinity && e.x.isInfinity)) &&
- (this.y == e.y || (this.y.isNaN && e.y.isNaN) || (this.y.isInfinity && e.y.isInfinity))
- } else {
- false
+ that match {
+ case e: ExamplePoint =>
+ (this.x == e.x || (this.x.isNaN && e.x.isNaN) || (this.x.isInfinity && e.x.isInfinity)) &&
+ (this.y == e.y || (this.y.isNaN && e.y.isNaN) || (this.y.isInfinity && e.y.isInfinity))
+ case _ => false
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
index 2f68e89..fa7140b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
@@ -158,11 +158,11 @@ private[sql] object ColumnAccessor {
def decompress(columnAccessor: ColumnAccessor, columnVector: WritableColumnVector, numRows: Int):
Unit = {
- if (columnAccessor.isInstanceOf[NativeColumnAccessor[_]]) {
- val nativeAccessor = columnAccessor.asInstanceOf[NativeColumnAccessor[_]]
- nativeAccessor.decompress(columnVector, numRows)
- } else {
- throw QueryExecutionErrors.notSupportNonPrimitiveTypeError()
+ columnAccessor match {
+ case nativeAccessor: NativeColumnAccessor[_] =>
+ nativeAccessor.decompress(columnVector, numRows)
+ case _ =>
+ throw QueryExecutionErrors.notSupportNonPrimitiveTypeError()
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala
index 419dcc6..9b4c136 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala
@@ -473,23 +473,25 @@ private[columnar] trait DirectCopyColumnType[JvmType] extends ColumnType[JvmType
// copy the bytes from ByteBuffer to UnsafeRow
override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = {
- if (row.isInstanceOf[MutableUnsafeRow]) {
- val numBytes = buffer.getInt
- val cursor = buffer.position()
- buffer.position(cursor + numBytes)
- row.asInstanceOf[MutableUnsafeRow].writer.write(ordinal, buffer.array(),
- buffer.arrayOffset() + cursor, numBytes)
- } else {
- setField(row, ordinal, extract(buffer))
+ row match {
+ case mutable: MutableUnsafeRow =>
+ val numBytes = buffer.getInt
+ val cursor = buffer.position()
+ buffer.position(cursor + numBytes)
+ mutable.writer.write(ordinal, buffer.array(),
+ buffer.arrayOffset() + cursor, numBytes)
+ case _ =>
+ setField(row, ordinal, extract(buffer))
}
}
// copy the bytes from UnsafeRow to ByteBuffer
override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = {
- if (row.isInstanceOf[UnsafeRow]) {
- row.asInstanceOf[UnsafeRow].writeFieldTo(ordinal, buffer)
- } else {
- super.append(row, ordinal, buffer)
+ row match {
+ case unsafe: UnsafeRow =>
+ unsafe.writeFieldTo(ordinal, buffer)
+ case _ =>
+ super.append(row, ordinal, buffer)
}
}
}
@@ -514,10 +516,11 @@ private[columnar] object STRING
}
override def setField(row: InternalRow, ordinal: Int, value: UTF8String): Unit = {
- if (row.isInstanceOf[MutableUnsafeRow]) {
- row.asInstanceOf[MutableUnsafeRow].writer.write(ordinal, value)
- } else {
- row.update(ordinal, value.clone())
+ row match {
+ case mutable: MutableUnsafeRow =>
+ mutable.writer.write(ordinal, value)
+ case _ =>
+ row.update(ordinal, value.clone())
}
}
@@ -792,13 +795,14 @@ private[columnar] object CALENDAR_INTERVAL extends ColumnType[CalendarInterval]
// copy the bytes from ByteBuffer to UnsafeRow
override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = {
- if (row.isInstanceOf[MutableUnsafeRow]) {
- val cursor = buffer.position()
- buffer.position(cursor + defaultSize)
- row.asInstanceOf[MutableUnsafeRow].writer.write(ordinal, buffer.array(),
- buffer.arrayOffset() + cursor, defaultSize)
- } else {
- setField(row, ordinal, extract(buffer))
+ row match {
+ case mutable: MutableUnsafeRow =>
+ val cursor = buffer.position()
+ buffer.position(cursor + defaultSize)
+ mutable.writer.write(ordinal, buffer.array(),
+ buffer.arrayOffset() + cursor, defaultSize)
+ case _ =>
+ setField(row, ordinal, extract(buffer))
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
index 47f279b..5baa597 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
@@ -214,16 +214,17 @@ class FileScanRDD(
val nextElement = currentIterator.next()
// TODO: we should have a better separation of row based and batch based scan, so that we
// don't need to run this `if` for every record.
- if (nextElement.isInstanceOf[ColumnarBatch]) {
- incTaskInputMetricsBytesRead()
- inputMetrics.incRecordsRead(nextElement.asInstanceOf[ColumnarBatch].numRows())
- } else {
- // too costly to update every record
- if (inputMetrics.recordsRead %
- SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) {
+ nextElement match {
+ case batch: ColumnarBatch =>
incTaskInputMetricsBytesRead()
- }
- inputMetrics.incRecordsRead(1)
+ inputMetrics.incRecordsRead(batch.numRows())
+ case _ =>
+ // too costly to update every record
+ if (inputMetrics.recordsRead %
+ SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) {
+ incTaskInputMetricsBytesRead()
+ }
+ inputMetrics.incRecordsRead(1)
}
addMetadataColumnsIfNeeded(nextElement)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
index 1f422e5..7bd51f8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
@@ -65,20 +65,22 @@ private object H2Dialect extends JdbcDialect {
}
override def classifyException(message: String, e: Throwable): AnalysisException = {
- if (e.isInstanceOf[SQLException]) {
- // Error codes are from https://www.h2database.com/javadoc/org/h2/api/ErrorCode.html
- e.asInstanceOf[SQLException].getErrorCode match {
- // TABLE_OR_VIEW_ALREADY_EXISTS_1
- case 42101 =>
- throw new TableAlreadyExistsException(message, cause = Some(e))
- // TABLE_OR_VIEW_NOT_FOUND_1
- case 42102 =>
- throw new NoSuchTableException(message, cause = Some(e))
- // SCHEMA_NOT_FOUND_1
- case 90079 =>
- throw new NoSuchNamespaceException(message, cause = Some(e))
- case _ =>
- }
+ e match {
+ case exception: SQLException =>
+ // Error codes are from https://www.h2database.com/javadoc/org/h2/api/ErrorCode.html
+ exception.getErrorCode match {
+ // TABLE_OR_VIEW_ALREADY_EXISTS_1
+ case 42101 =>
+ throw new TableAlreadyExistsException(message, cause = Some(e))
+ // TABLE_OR_VIEW_NOT_FOUND_1
+ case 42102 =>
+            throw new NoSuchTableException(message, cause = Some(e))
+ // SCHEMA_NOT_FOUND_1
+ case 90079 =>
+            throw new NoSuchNamespaceException(message, cause = Some(e))
+ case _ => // do nothing
+ }
+ case _ => // do nothing
}
super.classifyException(message, e)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index 4994968..3577812 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -725,37 +725,32 @@ class BrokenColumnarAdd(
lhs = left.columnarEval(batch)
rhs = right.columnarEval(batch)
- if (lhs == null || rhs == null) {
- ret = null
- } else if (lhs.isInstanceOf[ColumnVector] && rhs.isInstanceOf[ColumnVector]) {
- val l = lhs.asInstanceOf[ColumnVector]
- val r = rhs.asInstanceOf[ColumnVector]
- val result = new OnHeapColumnVector(batch.numRows(), dataType)
- ret = result
-
- for (i <- 0 until batch.numRows()) {
- result.appendLong(l.getLong(i) + r.getLong(i) + 1) // BUG to show we replaced Add
- }
- } else if (rhs.isInstanceOf[ColumnVector]) {
- val l = lhs.asInstanceOf[Long]
- val r = rhs.asInstanceOf[ColumnVector]
- val result = new OnHeapColumnVector(batch.numRows(), dataType)
- ret = result
-
- for (i <- 0 until batch.numRows()) {
- result.appendLong(l + r.getLong(i) + 1) // BUG to show we replaced Add
- }
- } else if (lhs.isInstanceOf[ColumnVector]) {
- val l = lhs.asInstanceOf[ColumnVector]
- val r = rhs.asInstanceOf[Long]
- val result = new OnHeapColumnVector(batch.numRows(), dataType)
- ret = result
-
- for (i <- 0 until batch.numRows()) {
- result.appendLong(l.getLong(i) + r + 1) // BUG to show we replaced Add
- }
- } else {
- ret = nullSafeEval(lhs, rhs)
+ (lhs, rhs) match {
+ case (null, null) =>
+ ret = null
+ case (l: ColumnVector, r: ColumnVector) =>
+ val result = new OnHeapColumnVector(batch.numRows(), dataType)
+ ret = result
+
+ for (i <- 0 until batch.numRows()) {
+ result.appendLong(l.getLong(i) + r.getLong(i) + 1) // BUG to show we replaced Add
+ }
+ case (l: Long, r: ColumnVector) =>
+ val result = new OnHeapColumnVector(batch.numRows(), dataType)
+ ret = result
+
+ for (i <- 0 until batch.numRows()) {
+ result.appendLong(l + r.getLong(i) + 1) // BUG to show we replaced Add
+ }
+ case (l: ColumnVector, r: Long) =>
+ val result = new OnHeapColumnVector(batch.numRows(), dataType)
+ ret = result
+
+ for (i <- 0 until batch.numRows()) {
+ result.appendLong(l.getLong(i) + r + 1) // BUG to show we replaced Add
+ }
+ case (l, r) =>
+ ret = nullSafeEval(l, r)
}
} finally {
if (lhs != null && lhs.isInstanceOf[ColumnVector]) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index a8b4856..f27a249 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -402,13 +402,12 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils
assert(b.buildSide === buildSide)
case w: WholeStageCodegenExec =>
assert(w.children.head.getClass.getSimpleName === joinMethod)
- if (w.children.head.isInstanceOf[BroadcastNestedLoopJoinExec]) {
- assert(
- w.children.head.asInstanceOf[BroadcastNestedLoopJoinExec].buildSide === buildSide)
- } else if (w.children.head.isInstanceOf[BroadcastHashJoinExec]) {
- assert(w.children.head.asInstanceOf[BroadcastHashJoinExec].buildSide === buildSide)
- } else {
- fail()
+ w.children.head match {
+ case bnlj: BroadcastNestedLoopJoinExec =>
+ assert(bnlj.buildSide === buildSide)
+ case bhj: BroadcastHashJoinExec =>
+ assert(bhj.buildSide === buildSide)
+ case _ => fail()
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index ff182b5..2bb43ec 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -528,8 +528,10 @@ trait StreamTest extends QueryTest with SharedSparkSession with TimeLimits with
verify(triggerClock.isInstanceOf[SystemClock]
|| triggerClock.isInstanceOf[StreamManualClock],
"Use either SystemClock or StreamManualClock to start the stream")
- if (triggerClock.isInstanceOf[StreamManualClock]) {
- manualClockExpectedTime = triggerClock.asInstanceOf[StreamManualClock].getTimeMillis()
+ triggerClock match {
+ case clock: StreamManualClock =>
+ manualClockExpectedTime = clock.getTimeMillis()
+ case _ =>
}
val metadataRoot = Option(checkpointLocation).getOrElse(defaultCheckpointLocation)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala
index 828f987..671b80f 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala
@@ -316,12 +316,12 @@ private[hive] class IsolatedClientLoader(
.asInstanceOf[HiveClient]
} catch {
case e: InvocationTargetException =>
- if (e.getCause().isInstanceOf[NoClassDefFoundError]) {
- val cnf = e.getCause().asInstanceOf[NoClassDefFoundError]
- throw QueryExecutionErrors.loadHiveClientCausesNoClassDefFoundError(
- cnf, execJars, HiveUtils.HIVE_METASTORE_JARS.key, e)
- } else {
- throw e
+ e.getCause match {
+ case cnf: NoClassDefFoundError =>
+ throw QueryExecutionErrors.loadHiveClientCausesNoClassDefFoundError(
+ cnf, execJars, HiveUtils.HIVE_METASTORE_JARS.key, e)
+ case _ =>
+ throw e
}
} finally {
Thread.currentThread.setContextClassLoader(origLoader)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
index 8008a5c..282946dd 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
@@ -204,10 +204,12 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
// If manual clock is being used for testing, then
// either set the manual clock to the last checkpointed time,
// or if the property is defined set it to that time
- if (clock.isInstanceOf[ManualClock]) {
- val lastTime = ssc.initialCheckpoint.checkpointTime.milliseconds
- val jumpTime = ssc.sc.conf.get(StreamingConf.MANUAL_CLOCK_JUMP)
- clock.asInstanceOf[ManualClock].setTime(lastTime + jumpTime)
+ clock match {
+ case manualClock: ManualClock =>
+ val lastTime = ssc.initialCheckpoint.checkpointTime.milliseconds
+ val jumpTime = ssc.sc.conf.get(StreamingConf.MANUAL_CLOCK_JUMP)
+ manualClock.setTime(lastTime + jumpTime)
+ case _ => // do nothing
}
val batchDuration = ssc.graph.batchDuration
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
index 4224cef..8069e79 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
@@ -296,16 +296,17 @@ private[streaming] class OpenHashMapBasedStateMap[K, S](
var parentSessionLoopDone = false
while(!parentSessionLoopDone) {
val obj = inputStream.readObject()
- if (obj.isInstanceOf[LimitMarker]) {
- parentSessionLoopDone = true
- val expectedCount = obj.asInstanceOf[LimitMarker].num
- assert(expectedCount == newParentSessionStore.deltaMap.size)
- } else {
- val key = obj.asInstanceOf[K]
- val state = inputStream.readObject().asInstanceOf[S]
- val updateTime = inputStream.readLong()
- newParentSessionStore.deltaMap.update(
- key, StateInfo(state, updateTime, deleted = false))
+ obj match {
+ case marker: LimitMarker =>
+ parentSessionLoopDone = true
+ val expectedCount = marker.num
+ assert(expectedCount == newParentSessionStore.deltaMap.size)
+ case _ =>
+ val key = obj.asInstanceOf[K]
+ val state = inputStream.readObject().asInstanceOf[S]
+ val updateTime = inputStream.readLong()
+ newParentSessionStore.deltaMap.update(
+ key, StateInfo(state, updateTime, deleted = false))
}
}
parentStateMap = newParentSessionStore