This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new c7c51bc  [SPARK-37854][CORE] Replace type check with pattern matching in Spark code
c7c51bc is described below

commit c7c51bcab5cb067d36bccf789e0e4ad7f37ffb7c
Author: yangjie01 <yangji...@baidu.com>
AuthorDate: Sat Jan 15 08:54:16 2022 -0600

    [SPARK-37854][CORE] Replace type check with pattern matching in Spark code

    ### What changes were proposed in this pull request?
    Many methods in the Spark code base currently use `isInstanceOf` + `asInstanceOf` for type
    checks and conversions. The main change of this PR is to replace that `type check` idiom with
    `pattern matching` to simplify the code (a minimal before/after sketch of the idiom appears
    after the diff below).

    ### Why are the changes needed?
    Code simplification.

    ### Does this PR introduce _any_ user-facing change?
    No.

    ### How was this patch tested?
    Pass GA.

    Closes #35154 from LuciferYang/SPARK-37854.

    Authored-by: yangjie01 <yangji...@baidu.com>
    Signed-off-by: Sean Owen <sro...@gmail.com>
---
 .../main/scala/org/apache/spark/TestUtils.scala | 36 ++++++------
 .../main/scala/org/apache/spark/api/r/SerDe.scala | 12 ++--
 .../spark/internal/config/ConfigBuilder.scala | 18 +++---
 .../scala/org/apache/spark/rdd/HadoopRDD.scala | 64 +++++++++++-----------
 .../main/scala/org/apache/spark/rdd/PipedRDD.scala | 7 ++-
 core/src/main/scala/org/apache/spark/rdd/RDD.scala | 8 ++-
 .../main/scala/org/apache/spark/util/Utils.scala | 38 ++++++-------
 .../storage/ShuffleBlockFetcherIteratorSuite.scala | 10 ++--
 .../org/apache/spark/util/FileAppenderSuite.scala | 17 +++---
 .../scala/org/apache/spark/util/UtilsSuite.scala | 19 ++++---
 .../apache/spark/examples/mllib/LDAExample.scala | 11 ++--
 .../spark/mllib/api/python/PythonMLLibAPI.scala | 12 ++--
 .../expressions/aggregate/Percentile.scala | 14 ++---
 .../apache/spark/sql/catalyst/trees/TreeNode.scala | 7 +--
 .../sql/catalyst/encoders/RowEncoderSuite.scala | 11 ++--
 .../sql/execution/columnar/ColumnAccessor.scala | 10 ++--
 .../spark/sql/execution/columnar/ColumnType.scala | 50 +++++++++--------
 .../sql/execution/datasources/FileScanRDD.scala | 19 ++++---
 .../org/apache/spark/sql/jdbc/H2Dialect.scala | 30 +++++-----
 .../spark/sql/SparkSessionExtensionSuite.scala | 57 +++++++++----------
 .../sql/execution/joins/BroadcastJoinSuite.scala | 13 ++---
 .../apache/spark/sql/streaming/StreamTest.scala | 6 +-
 .../sql/hive/client/IsolatedClientLoader.scala | 12 ++--
 .../spark/streaming/scheduler/JobGenerator.scala | 10 ++--
 .../org/apache/spark/streaming/util/StateMap.scala | 21 +++----
 25 files changed, 263 insertions(+), 249 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala
index 20159af..d2af955 100644
--- a/core/src/main/scala/org/apache/spark/TestUtils.scala
+++ b/core/src/main/scala/org/apache/spark/TestUtils.scala
@@ -337,22 +337,26 @@ private[spark] object TestUtils {
     connection.setRequestMethod(method)
     headers.foreach { case (k, v) => connection.setRequestProperty(k, v) }

-    // Disable cert and host name validation for HTTPS tests.
-    if (connection.isInstanceOf[HttpsURLConnection]) {
-      val sslCtx = SSLContext.getInstance("SSL")
-      val trustManager = new X509TrustManager {
-        override def getAcceptedIssuers(): Array[X509Certificate] = null
-        override def checkClientTrusted(x509Certificates: Array[X509Certificate],
-            s: String): Unit = {}
-        override def checkServerTrusted(x509Certificates: Array[X509Certificate],
-            s: String): Unit = {}
-      }
-      val verifier = new HostnameVerifier() {
-        override def verify(hostname: String, session: SSLSession): Boolean = true
-      }
-      sslCtx.init(null, Array(trustManager), new SecureRandom())
-      connection.asInstanceOf[HttpsURLConnection].setSSLSocketFactory(sslCtx.getSocketFactory())
-      connection.asInstanceOf[HttpsURLConnection].setHostnameVerifier(verifier)
+    connection match {
+      // Disable cert and host name validation for HTTPS tests.
+      case httpConnection: HttpsURLConnection =>
+        val sslCtx = SSLContext.getInstance("SSL")
+        val trustManager = new X509TrustManager {
+          override def getAcceptedIssuers: Array[X509Certificate] = null
+
+          override def checkClientTrusted(x509Certificates: Array[X509Certificate],
+              s: String): Unit = {}
+
+          override def checkServerTrusted(x509Certificates: Array[X509Certificate],
+              s: String): Unit = {}
+        }
+        val verifier = new HostnameVerifier() {
+          override def verify(hostname: String, session: SSLSession): Boolean = true
+        }
+        sslCtx.init(null, Array(trustManager), new SecureRandom())
+        httpConnection.setSSLSocketFactory(sslCtx.getSocketFactory)
+        httpConnection.setHostnameVerifier(verifier)
+      case _ => // do nothing
     }

     try {
diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
index 9172038..f9f8c56 100644
--- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala
@@ -22,7 +22,7 @@ import java.nio.charset.StandardCharsets
 import java.sql.{Date, Time, Timestamp}

 import scala.collection.JavaConverters._
-import scala.collection.mutable.WrappedArray
+import scala.collection.mutable

 /**
  * Utility functions to serialize, deserialize objects to / from R
@@ -303,12 +303,10 @@ private[spark] object SerDe {
       // Convert ArrayType collected from DataFrame to Java array
       // Collected data of ArrayType from a DataFrame is observed to be of
       // type "scala.collection.mutable.WrappedArray"
-      val value =
-        if (obj.isInstanceOf[WrappedArray[_]]) {
-          obj.asInstanceOf[WrappedArray[_]].toArray
-        } else {
-          obj
-        }
+      val value = obj match {
+        case wa: mutable.WrappedArray[_] => wa.array
+        case other => other
+      }

       value match {
         case v: java.lang.Character =>
diff --git a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala
index 38e057b..e319026 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala
@@ -140,15 +140,15 @@ private[spark] class TypedConfigBuilder[T](
   def createWithDefault(default: T): ConfigEntry[T] = {
     // Treat "String" as a special case, so that both createWithDefault and createWithDefaultString
     // behave the same w.r.t. variable expansion of default values.
-    if (default.isInstanceOf[String]) {
-      createWithDefaultString(default.asInstanceOf[String])
-    } else {
-      val transformedDefault = converter(stringConverter(default))
-      val entry = new ConfigEntryWithDefault[T](parent.key, parent._prependedKey,
-        parent._prependSeparator, parent._alternatives, transformedDefault, converter,
-        stringConverter, parent._doc, parent._public, parent._version)
-      parent._onCreate.foreach(_(entry))
-      entry
+    default match {
+      case str: String => createWithDefaultString(str)
+      case _ =>
+        val transformedDefault = converter(stringConverter(default))
+        val entry = new ConfigEntryWithDefault[T](parent.key, parent._prependedKey,
+          parent._prependSeparator, parent._alternatives, transformedDefault, converter,
+          stringConverter, parent._doc, parent._public, parent._version)
+        parent._onCreate.foreach(_ (entry))
+        entry
     }
   }

diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 7011451..fcc2275 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -61,14 +61,14 @@ private[spark] class HadoopPartition(rddId: Int, override val index: Int, s: Inp
    * @return a Map with the environment variables and corresponding values, it could be empty
    */
   def getPipeEnvVars(): Map[String, String] = {
-    val envVars: Map[String, String] = if (inputSplit.value.isInstanceOf[FileSplit]) {
-      val is: FileSplit = inputSplit.value.asInstanceOf[FileSplit]
-      // map_input_file is deprecated in favor of mapreduce_map_input_file but set both
-      // since it's not removed yet
-      Map("map_input_file" -> is.getPath().toString(),
-        "mapreduce_map_input_file" -> is.getPath().toString())
-    } else {
-      Map()
+    val envVars: Map[String, String] = inputSplit.value match {
+      case is: FileSplit =>
+        // map_input_file is deprecated in favor of mapreduce_map_input_file but set both
+        // since it's not removed yet
+        Map("map_input_file" -> is.getPath().toString(),
+          "mapreduce_map_input_file" -> is.getPath().toString())
+      case _ =>
+        Map()
     }
     envVars
   }
@@ -161,29 +161,31 @@ class HadoopRDD[K, V](
         newJobConf
       }
     } else {
-      if (conf.isInstanceOf[JobConf]) {
-        logDebug("Re-using user-broadcasted JobConf")
-        conf.asInstanceOf[JobConf]
-      } else {
-        Option(HadoopRDD.getCachedMetadata(jobConfCacheKey))
-          .map { conf =>
-            logDebug("Re-using cached JobConf")
-            conf.asInstanceOf[JobConf]
-          }
-          .getOrElse {
-            // Create a JobConf that will be cached and used across this RDD's getJobConf() calls in
-            // the local process. The local cache is accessed through HadoopRDD.putCachedMetadata().
-            // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary
-            // objects. Synchronize to prevent ConcurrentModificationException (SPARK-1097,
-            // HADOOP-10456).
-            HadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized {
-              logDebug("Creating new JobConf and caching it for later re-use")
-              val newJobConf = new JobConf(conf)
-              initLocalJobConfFuncOpt.foreach(f => f(newJobConf))
-              HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf)
-              newJobConf
-            }
-          }
+      conf match {
+        case jobConf: JobConf =>
+          logDebug("Re-using user-broadcasted JobConf")
+          jobConf
+        case _ =>
+          Option(HadoopRDD.getCachedMetadata(jobConfCacheKey))
+            .map { conf =>
+              logDebug("Re-using cached JobConf")
+              conf.asInstanceOf[JobConf]
+            }
+            .getOrElse {
+              // Create a JobConf that will be cached and used across this RDD's getJobConf()
+              // calls in the local process. The local cache is accessed through
+              // HadoopRDD.putCachedMetadata().
+              // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary
+              // objects. Synchronize to prevent ConcurrentModificationException (SPARK-1097,
+              // HADOOP-10456).
+              HadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized {
+                logDebug("Creating new JobConf and caching it for later re-use")
+                val newJobConf = new JobConf(conf)
+                initLocalJobConfFuncOpt.foreach(f => f(newJobConf))
+                HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf)
+                newJobConf
+              }
+            }
       }
     }
   }
diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
index 285da04..7e121e9 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
@@ -72,9 +72,10 @@ private[spark] class PipedRDD[T: ClassTag](

     // for compatibility with Hadoop which sets these env variables
     // so the user code can access the input filename
-    if (split.isInstanceOf[HadoopPartition]) {
-      val hadoopSplit = split.asInstanceOf[HadoopPartition]
-      currentEnvVars.putAll(hadoopSplit.getPipeEnvVars().asJava)
+    split match {
+      case hadoopSplit: HadoopPartition =>
+        currentEnvVars.putAll(hadoopSplit.getPipeEnvVars().asJava)
+      case _ => // do nothing
     }

     // When spark.worker.separated.working.directory option is turned on, each
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 4c39d17..7188566 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1764,9 +1764,11 @@ abstract class RDD[T: ClassTag](
      * Clean the shuffles & all of its parents.
      */
     def cleanEagerly(dep: Dependency[_]): Unit = {
-      if (dep.isInstanceOf[ShuffleDependency[_, _, _]]) {
-        val shuffleId = dep.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
-        cleaner.doCleanupShuffle(shuffleId, blocking)
+      dep match {
+        case dependency: ShuffleDependency[_, _, _] =>
+          val shuffleId = dependency.shuffleId
+          cleaner.doCleanupShuffle(shuffleId, blocking)
+        case _ => // do nothing
       }
       val rdd = dep.rdd
       val rddDepsOpt = rdd.internalDependencies
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 8f3d1de..a9d6180 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -355,26 +355,26 @@ private[spark] object Utils extends Logging {
       closeStreams: Boolean = false,
       transferToEnabled: Boolean = false): Long = {
     tryWithSafeFinally {
-      if (in.isInstanceOf[FileInputStream] && out.isInstanceOf[FileOutputStream]
-        && transferToEnabled) {
-        // When both streams are File stream, use transferTo to improve copy performance.
-        val inChannel = in.asInstanceOf[FileInputStream].getChannel()
-        val outChannel = out.asInstanceOf[FileOutputStream].getChannel()
-        val size = inChannel.size()
-        copyFileStreamNIO(inChannel, outChannel, 0, size)
-        size
-      } else {
-        var count = 0L
-        val buf = new Array[Byte](8192)
-        var n = 0
-        while (n != -1) {
-          n = in.read(buf)
-          if (n != -1) {
-            out.write(buf, 0, n)
-            count += n
+      (in, out) match {
+        case (input: FileInputStream, output: FileOutputStream) if transferToEnabled =>
+          // When both streams are File stream, use transferTo to improve copy performance.
+          val inChannel = input.getChannel
+          val outChannel = output.getChannel
+          val size = inChannel.size()
+          copyFileStreamNIO(inChannel, outChannel, 0, size)
+          size
+        case (input, output) =>
+          var count = 0L
+          val buf = new Array[Byte](8192)
+          var n = 0
+          while (n != -1) {
+            n = input.read(buf)
+            if (n != -1) {
+              output.write(buf, 0, n)
+              count += n
+            }
           }
-        }
-        count
+          count
       }
     } {
       if (closeStreams) {
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index afb9a86..56043ea 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -160,10 +160,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     verify(buffer, times(0)).release()
     val delegateAccess = PrivateMethod[InputStream](Symbol("delegate"))
     var in = wrappedInputStream.invokePrivate(delegateAccess())
-    if (in.isInstanceOf[CheckedInputStream]) {
-      val underlyingInputFiled = classOf[CheckedInputStream].getSuperclass.getDeclaredField("in")
-      underlyingInputFiled.setAccessible(true)
-      in = underlyingInputFiled.get(in.asInstanceOf[CheckedInputStream]).asInstanceOf[InputStream]
+    in match {
+      case stream: CheckedInputStream =>
+        val underlyingInputFiled = classOf[CheckedInputStream].getSuperclass.getDeclaredField("in")
+        underlyingInputFiled.setAccessible(true)
+        in = underlyingInputFiled.get(stream).asInstanceOf[InputStream]
+      case _ => // do nothing
     }
     verify(in, times(0)).close()
     wrappedInputStream.close()
diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala
index 1a2eb69..8ca4bc9 100644
--- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala
@@ -222,14 +222,15 @@ class FileAppenderSuite extends SparkFunSuite with BeforeAndAfter with Logging {
     // assert(appender.getClass === classTag[ExpectedAppender].getClass)
     assert(appender.getClass.getSimpleName ===
       classTag[ExpectedAppender].runtimeClass.getSimpleName)
-    if (appender.isInstanceOf[RollingFileAppender]) {
-      val rollingPolicy = appender.asInstanceOf[RollingFileAppender].rollingPolicy
-      val policyParam = if (rollingPolicy.isInstanceOf[TimeBasedRollingPolicy]) {
-        rollingPolicy.asInstanceOf[TimeBasedRollingPolicy].rolloverIntervalMillis
-      } else {
-        rollingPolicy.asInstanceOf[SizeBasedRollingPolicy].rolloverSizeBytes
-      }
-      assert(policyParam === expectedRollingPolicyParam)
+    appender match {
+      case rfa: RollingFileAppender =>
+        val rollingPolicy = rfa.rollingPolicy
+        val policyParam = rollingPolicy match {
+          case timeBased: TimeBasedRollingPolicy => timeBased.rolloverIntervalMillis
+          case sizeBased: SizeBasedRollingPolicy => sizeBased.rolloverSizeBytes
+        }
+        assert(policyParam === expectedRollingPolicyParam)
+      case _ => // do nothing
     }
     testOutputStream.close()
     appender.awaitTermination()
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index 6117dec..62cd819 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -227,15 +227,16 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
     try {
       // Get a handle on the buffered data, to make sure memory gets freed once we read past the
      // end of it. Need to use reflection to get handle on inner structures for this check
-      val byteBufferInputStream = if (mergedStream.isInstanceOf[ChunkedByteBufferInputStream]) {
-        assert(inputLength < limit)
-        mergedStream.asInstanceOf[ChunkedByteBufferInputStream]
-      } else {
-        assert(inputLength >= limit)
-        val sequenceStream = mergedStream.asInstanceOf[SequenceInputStream]
-        val fieldValue = getFieldValue(sequenceStream, "in")
-        assert(fieldValue.isInstanceOf[ChunkedByteBufferInputStream])
-        fieldValue.asInstanceOf[ChunkedByteBufferInputStream]
+      val byteBufferInputStream = mergedStream match {
+        case stream: ChunkedByteBufferInputStream =>
+          assert(inputLength < limit)
+          stream
+        case _ =>
+          assert(inputLength >= limit)
+          val sequenceStream = mergedStream.asInstanceOf[SequenceInputStream]
+          val fieldValue = getFieldValue(sequenceStream, "in")
+          assert(fieldValue.isInstanceOf[ChunkedByteBufferInputStream])
+          fieldValue.asInstanceOf[ChunkedByteBufferInputStream]
       }
       (0 until inputLength).foreach { idx =>
         assert(bytes(idx) === mergedStream.read().asInstanceOf[Byte])
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
index a3006a1..afd529c 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
@@ -158,11 +158,12 @@ object LDAExample {
     println(s"Finished training LDA model.  Summary:")
     println(s"\t Training time: $elapsed sec")

-    if (ldaModel.isInstanceOf[DistributedLDAModel]) {
-      val distLDAModel = ldaModel.asInstanceOf[DistributedLDAModel]
-      val avgLogLikelihood = distLDAModel.logLikelihood / actualCorpusSize.toDouble
-      println(s"\t Training data average log likelihood: $avgLogLikelihood")
-      println()
+    ldaModel match {
+      case distLDAModel: DistributedLDAModel =>
+        val avgLogLikelihood = distLDAModel.logLikelihood / actualCorpusSize.toDouble
+        println(s"\t Training data average log likelihood: $avgLogLikelihood")
+        println()
+      case _ => // do nothing
     }

     // Print the topics, showing the top-weighted terms for each topic.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 80707f0..56aaaa3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -90,12 +90,12 @@ private[python] class PythonMLLibAPI extends Serializable {
       initialWeights: Vector): JList[Object] = {
     try {
       val model = learner.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK), initialWeights)
-      if (model.isInstanceOf[LogisticRegressionModel]) {
-        val lrModel = model.asInstanceOf[LogisticRegressionModel]
-        List(lrModel.weights, lrModel.intercept, lrModel.numFeatures, lrModel.numClasses)
-          .map(_.asInstanceOf[Object]).asJava
-      } else {
-        List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava
+      model match {
+        case lrModel: LogisticRegressionModel =>
+          List(lrModel.weights, lrModel.intercept, lrModel.numFeatures, lrModel.numClasses)
+            .map(_.asInstanceOf[Object]).asJava
+        case _ =>
+          List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava
       }
     } finally {
       data.rdd.unpersist()
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
index 7d3dd0a..a98585e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
@@ -198,14 +198,12 @@ case class Percentile(
       return Seq.empty
     }

-    val ordering =
-      if (child.dataType.isInstanceOf[NumericType]) {
-        child.dataType.asInstanceOf[NumericType].ordering
-      } else if (child.dataType.isInstanceOf[YearMonthIntervalType]) {
-        child.dataType.asInstanceOf[YearMonthIntervalType].ordering
-      } else if (child.dataType.isInstanceOf[DayTimeIntervalType]) {
-        child.dataType.asInstanceOf[DayTimeIntervalType].ordering
-      }
+    val ordering = child.dataType match {
+      case numericType: NumericType => numericType.ordering
+      case intervalType: YearMonthIntervalType => intervalType.ordering
+      case intervalType: DayTimeIntervalType => intervalType.ordering
+      case otherType => QueryExecutionErrors.unsupportedTypeError(otherType)
+    }
     val sortedCounts = buffer.toSeq.sortBy(_._1)(ordering.asInstanceOf[Ordering[AnyRef]])
     val accumulatedCounts = sortedCounts.scanLeft((sortedCounts.head._1, 0L)) {
       case ((key1, count1), (key2, count2)) => (key2, count1 + count2)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index f78bbbf..9e50be3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -341,10 +341,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product with Tre
   // This is a temporary solution, we will change the type of children to IndexedSeq in a
   // followup PR
   private def asIndexedSeq(seq: Seq[BaseType]): IndexedSeq[BaseType] = {
-    if (seq.isInstanceOf[IndexedSeq[BaseType]]) {
-      seq.asInstanceOf[IndexedSeq[BaseType]]
-    } else {
-      seq.toIndexedSeq
+    seq match {
+      case types: IndexedSeq[BaseType] => types
+      case other => other.toIndexedSeq
     }
   }

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 1a42784..44b06d9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -31,12 +31,11 @@ import org.apache.spark.sql.types.DataTypeTestUtils.{dayTimeIntervalTypes, yearM
 class ExamplePoint(val x: Double, val y: Double) extends Serializable {
   override def hashCode: Int = 41 * (41 + x.toInt) + y.toInt
   override def equals(that: Any): Boolean = {
-    if (that.isInstanceOf[ExamplePoint]) {
-      val e = that.asInstanceOf[ExamplePoint]
-      (this.x == e.x || (this.x.isNaN && e.x.isNaN) || (this.x.isInfinity && e.x.isInfinity)) &&
-      (this.y == e.y || (this.y.isNaN && e.y.isNaN) || (this.y.isInfinity && e.y.isInfinity))
-    } else {
-      false
+    that match {
+      case e: ExamplePoint =>
+        (this.x == e.x || (this.x.isNaN && e.x.isNaN) || (this.x.isInfinity && e.x.isInfinity)) &&
+          (this.y == e.y || (this.y.isNaN && e.y.isNaN) || (this.y.isInfinity && e.y.isInfinity))
+      case _ => false
     }
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
index 2f68e89..fa7140b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
@@ -158,11 +158,11 @@ private[sql] object ColumnAccessor {

   def decompress(columnAccessor: ColumnAccessor, columnVector: WritableColumnVector,
       numRows: Int): Unit = {
-    if (columnAccessor.isInstanceOf[NativeColumnAccessor[_]]) {
-      val nativeAccessor = columnAccessor.asInstanceOf[NativeColumnAccessor[_]]
-      nativeAccessor.decompress(columnVector, numRows)
-    } else {
-      throw QueryExecutionErrors.notSupportNonPrimitiveTypeError()
+    columnAccessor match {
+      case nativeAccessor: NativeColumnAccessor[_] =>
+        nativeAccessor.decompress(columnVector, numRows)
+      case _ =>
+        throw QueryExecutionErrors.notSupportNonPrimitiveTypeError()
     }
   }

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala
index 419dcc6..9b4c136 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala
@@ -473,23 +473,25 @@ private[columnar] trait DirectCopyColumnType[JvmType] extends ColumnType[JvmType

   // copy the bytes from ByteBuffer to UnsafeRow
   override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = {
-    if (row.isInstanceOf[MutableUnsafeRow]) {
-      val numBytes = buffer.getInt
-      val cursor = buffer.position()
-      buffer.position(cursor + numBytes)
-      row.asInstanceOf[MutableUnsafeRow].writer.write(ordinal, buffer.array(),
-        buffer.arrayOffset() + cursor, numBytes)
-    } else {
-      setField(row, ordinal, extract(buffer))
+    row match {
+      case mutable: MutableUnsafeRow =>
+        val numBytes = buffer.getInt
+        val cursor = buffer.position()
+        buffer.position(cursor + numBytes)
+        mutable.writer.write(ordinal, buffer.array(),
+          buffer.arrayOffset() + cursor, numBytes)
+      case _ =>
+        setField(row, ordinal, extract(buffer))
     }
   }

   // copy the bytes from UnsafeRow to ByteBuffer
   override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = {
-    if (row.isInstanceOf[UnsafeRow]) {
-      row.asInstanceOf[UnsafeRow].writeFieldTo(ordinal, buffer)
-    } else {
-      super.append(row, ordinal, buffer)
+    row match {
+      case unsafe: UnsafeRow =>
+        unsafe.writeFieldTo(ordinal, buffer)
+      case _ =>
+        super.append(row, ordinal, buffer)
     }
   }
 }
@@ -514,10 +516,11 @@ private[columnar] object STRING
   }

   override def setField(row: InternalRow, ordinal: Int, value: UTF8String): Unit = {
-    if (row.isInstanceOf[MutableUnsafeRow]) {
-      row.asInstanceOf[MutableUnsafeRow].writer.write(ordinal, value)
-    } else {
-      row.update(ordinal, value.clone())
+    row match {
+      case mutable: MutableUnsafeRow =>
+        mutable.writer.write(ordinal, value)
+      case _ =>
+        row.update(ordinal, value.clone())
     }
   }

@@ -792,13 +795,14 @@ private[columnar] object CALENDAR_INTERVAL extends ColumnType[CalendarInterval]

   // copy the bytes from ByteBuffer to UnsafeRow
   override def extract(buffer: ByteBuffer, row: InternalRow, ordinal: Int): Unit = {
-    if (row.isInstanceOf[MutableUnsafeRow]) {
-      val cursor = buffer.position()
-      buffer.position(cursor + defaultSize)
-      row.asInstanceOf[MutableUnsafeRow].writer.write(ordinal, buffer.array(),
-        buffer.arrayOffset() + cursor, defaultSize)
-    } else {
-      setField(row, ordinal, extract(buffer))
+    row match {
+      case mutable: MutableUnsafeRow =>
+        val cursor = buffer.position()
+        buffer.position(cursor + defaultSize)
+        mutable.writer.write(ordinal, buffer.array(),
+          buffer.arrayOffset() + cursor, defaultSize)
+      case _ =>
+        setField(row, ordinal, extract(buffer))
     }
   }

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
index 47f279b..5baa597 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
@@ -214,16 +214,17 @@ class FileScanRDD(
       val nextElement = currentIterator.next()
       // TODO: we should have a better separation of row based and batch based scan, so that we
       // don't need to run this `if` for every record.
-      if (nextElement.isInstanceOf[ColumnarBatch]) {
-        incTaskInputMetricsBytesRead()
-        inputMetrics.incRecordsRead(nextElement.asInstanceOf[ColumnarBatch].numRows())
-      } else {
-        // too costly to update every record
-        if (inputMetrics.recordsRead %
-            SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) {
+      nextElement match {
+        case batch: ColumnarBatch =>
           incTaskInputMetricsBytesRead()
-        }
-        inputMetrics.incRecordsRead(1)
+          inputMetrics.incRecordsRead(batch.numRows())
+        case _ =>
+          // too costly to update every record
+          if (inputMetrics.recordsRead %
+              SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) {
+            incTaskInputMetricsBytesRead()
+          }
+          inputMetrics.incRecordsRead(1)
       }
       addMetadataColumnsIfNeeded(nextElement)
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
index 1f422e5..7bd51f8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
@@ -65,20 +65,22 @@ private object H2Dialect extends JdbcDialect {
   }

   override def classifyException(message: String, e: Throwable): AnalysisException = {
-    if (e.isInstanceOf[SQLException]) {
-      // Error codes are from https://www.h2database.com/javadoc/org/h2/api/ErrorCode.html
-      e.asInstanceOf[SQLException].getErrorCode match {
-        // TABLE_OR_VIEW_ALREADY_EXISTS_1
-        case 42101 =>
-          throw new TableAlreadyExistsException(message, cause = Some(e))
-        // TABLE_OR_VIEW_NOT_FOUND_1
-        case 42102 =>
-          throw new NoSuchTableException(message, cause = Some(e))
-        // SCHEMA_NOT_FOUND_1
-        case 90079 =>
-          throw new NoSuchNamespaceException(message, cause = Some(e))
-        case _ =>
-      }
+    e match {
+      case exception: SQLException =>
+        // Error codes are from https://www.h2database.com/javadoc/org/h2/api/ErrorCode.html
+        exception.getErrorCode match {
+          // TABLE_OR_VIEW_ALREADY_EXISTS_1
+          case 42101 =>
+            throw new TableAlreadyExistsException(message, cause = Some(e))
+          // TABLE_OR_VIEW_NOT_FOUND_1
+          case 42102 =>
+            throw NoSuchTableException(message, cause = Some(e))
+          // SCHEMA_NOT_FOUND_1
+          case 90079 =>
+            throw NoSuchNamespaceException(message, cause = Some(e))
+          case _ => // do nothing
+        }
+      case _ => // do nothing
     }
     super.classifyException(message, e)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index 4994968..3577812 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -725,37 +725,32 @@ class BrokenColumnarAdd(
       lhs = left.columnarEval(batch)
       rhs = right.columnarEval(batch)

-      if (lhs == null || rhs == null) {
-        ret = null
-      } else if (lhs.isInstanceOf[ColumnVector] && rhs.isInstanceOf[ColumnVector]) {
-        val l = lhs.asInstanceOf[ColumnVector]
-        val r = rhs.asInstanceOf[ColumnVector]
-        val result = new OnHeapColumnVector(batch.numRows(), dataType)
-        ret = result
-
-        for (i <- 0 until batch.numRows()) {
-          result.appendLong(l.getLong(i) + r.getLong(i) + 1) // BUG to show we replaced Add
-        }
-      } else if (rhs.isInstanceOf[ColumnVector]) {
-        val l = lhs.asInstanceOf[Long]
-        val r = rhs.asInstanceOf[ColumnVector]
-        val result = new OnHeapColumnVector(batch.numRows(), dataType)
-        ret = result
-
-        for (i <- 0 until batch.numRows()) {
-          result.appendLong(l + r.getLong(i) + 1) // BUG to show we replaced Add
-        }
-      } else if (lhs.isInstanceOf[ColumnVector]) {
-        val l = lhs.asInstanceOf[ColumnVector]
-        val r = rhs.asInstanceOf[Long]
-        val result = new OnHeapColumnVector(batch.numRows(), dataType)
-        ret = result
-
-        for (i <- 0 until batch.numRows()) {
-          result.appendLong(l.getLong(i) + r + 1) // BUG to show we replaced Add
-        }
-      } else {
-        ret = nullSafeEval(lhs, rhs)
+      (lhs, rhs) match {
+        case (null, null) =>
+          ret = null
+        case (l: ColumnVector, r: ColumnVector) =>
+          val result = new OnHeapColumnVector(batch.numRows(), dataType)
+          ret = result
+
+          for (i <- 0 until batch.numRows()) {
+            result.appendLong(l.getLong(i) + r.getLong(i) + 1) // BUG to show we replaced Add
+          }
+        case (l: Long, r: ColumnVector) =>
+          val result = new OnHeapColumnVector(batch.numRows(), dataType)
+          ret = result
+
+          for (i <- 0 until batch.numRows()) {
+            result.appendLong(l + r.getLong(i) + 1) // BUG to show we replaced Add
+          }
+        case (l: ColumnVector, r: Long) =>
+          val result = new OnHeapColumnVector(batch.numRows(), dataType)
+          ret = result
+
+          for (i <- 0 until batch.numRows()) {
+            result.appendLong(l.getLong(i) + r + 1) // BUG to show we replaced Add
+          }
+        case (l, r) =>
+          ret = nullSafeEval(l, r)
       }
     } finally {
       if (lhs != null && lhs.isInstanceOf[ColumnVector]) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index a8b4856..f27a249 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -402,13 +402,12 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils
         assert(b.buildSide === buildSide)
       case w: WholeStageCodegenExec =>
         assert(w.children.head.getClass.getSimpleName === joinMethod)
-        if (w.children.head.isInstanceOf[BroadcastNestedLoopJoinExec]) {
-          assert(
-            w.children.head.asInstanceOf[BroadcastNestedLoopJoinExec].buildSide === buildSide)
-        } else if (w.children.head.isInstanceOf[BroadcastHashJoinExec]) {
-          assert(w.children.head.asInstanceOf[BroadcastHashJoinExec].buildSide === buildSide)
-        } else {
-          fail()
+        w.children.head match {
+          case bnlj: BroadcastNestedLoopJoinExec =>
+            assert(bnlj.buildSide === buildSide)
+          case bhj: BroadcastHashJoinExec =>
+            assert(bhj.buildSide === buildSide)
+          case _ => fail()
         }
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index ff182b5..2bb43ec 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -528,8 +528,10 @@ trait StreamTest extends QueryTest with SharedSparkSession with TimeLimits with

           verify(triggerClock.isInstanceOf[SystemClock] || triggerClock.isInstanceOf[StreamManualClock],
             "Use either SystemClock or StreamManualClock to start the stream")
-          if (triggerClock.isInstanceOf[StreamManualClock]) {
-            manualClockExpectedTime = triggerClock.asInstanceOf[StreamManualClock].getTimeMillis()
+          triggerClock match {
+            case clock: StreamManualClock =>
+              manualClockExpectedTime = clock.getTimeMillis()
+            case _ =>
           }

           val metadataRoot = Option(checkpointLocation).getOrElse(defaultCheckpointLocation)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala
index 828f987..671b80f 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala
@@ -316,12 +316,12 @@ private[hive] class IsolatedClientLoader(
         .asInstanceOf[HiveClient]
     } catch {
       case e: InvocationTargetException =>
-        if (e.getCause().isInstanceOf[NoClassDefFoundError]) {
-          val cnf = e.getCause().asInstanceOf[NoClassDefFoundError]
-          throw QueryExecutionErrors.loadHiveClientCausesNoClassDefFoundError(
-            cnf, execJars, HiveUtils.HIVE_METASTORE_JARS.key, e)
-        } else {
-          throw e
+        e.getCause match {
+          case cnf: NoClassDefFoundError =>
+            throw QueryExecutionErrors.loadHiveClientCausesNoClassDefFoundError(
+              cnf, execJars, HiveUtils.HIVE_METASTORE_JARS.key, e)
+          case _ =>
+            throw e
         }
     } finally {
       Thread.currentThread.setContextClassLoader(origLoader)
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
index 8008a5c..282946dd 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
@@ -204,10 +204,12 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
     // If manual clock is being used for testing, then
     // either set the manual clock to the last checkpointed time,
    // or if the property is defined set it to that time
-    if (clock.isInstanceOf[ManualClock]) {
-      val lastTime = ssc.initialCheckpoint.checkpointTime.milliseconds
-      val jumpTime = ssc.sc.conf.get(StreamingConf.MANUAL_CLOCK_JUMP)
-      clock.asInstanceOf[ManualClock].setTime(lastTime + jumpTime)
+    clock match {
+      case manualClock: ManualClock =>
+        val lastTime = ssc.initialCheckpoint.checkpointTime.milliseconds
+        val jumpTime = ssc.sc.conf.get(StreamingConf.MANUAL_CLOCK_JUMP)
+        manualClock.setTime(lastTime + jumpTime)
+      case _ => // do nothing
     }

     val batchDuration = ssc.graph.batchDuration
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
index 4224cef..8069e79 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala
@@ -296,16 +296,17 @@ private[streaming] class OpenHashMapBasedStateMap[K, S](
     var parentSessionLoopDone = false
     while(!parentSessionLoopDone) {
       val obj = inputStream.readObject()
-      if (obj.isInstanceOf[LimitMarker]) {
-        parentSessionLoopDone = true
-        val expectedCount = obj.asInstanceOf[LimitMarker].num
-        assert(expectedCount == newParentSessionStore.deltaMap.size)
-      } else {
-        val key = obj.asInstanceOf[K]
-        val state = inputStream.readObject().asInstanceOf[S]
-        val updateTime = inputStream.readLong()
-        newParentSessionStore.deltaMap.update(
-          key, StateInfo(state, updateTime, deleted = false))
+      obj match {
+        case marker: LimitMarker =>
+          parentSessionLoopDone = true
+          val expectedCount = marker.num
+          assert(expectedCount == newParentSessionStore.deltaMap.size)
+        case _ =>
+          val key = obj.asInstanceOf[K]
+          val state = inputStream.readObject().asInstanceOf[S]
+          val updateTime = inputStream.readLong()
+          newParentSessionStore.deltaMap.update(
+            key, StateInfo(state, updateTime, deleted = false))
       }
     }
     parentStateMap = newParentSessionStore
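
The diff above applies one idiom over and over: an `isInstanceOf` test followed by an
`asInstanceOf` cast is collapsed into a `match` with a typed pattern. The sketch below is a
minimal, self-contained illustration of that before/after shape; the `Shape`, `Circle`, and
`Square` types are invented for the example and are not code from this commit or from Spark.

// Minimal illustration of the refactoring idiom applied throughout this commit.
// The Shape hierarchy here is made up purely for the example.
object PatternMatchSketch {
  sealed trait Shape
  final case class Circle(radius: Double) extends Shape
  final case class Square(side: Double) extends Shape

  // Before: explicit type check followed by a cast.
  def areaBefore(s: Shape): Double = {
    if (s.isInstanceOf[Circle]) {
      val c = s.asInstanceOf[Circle]
      math.Pi * c.radius * c.radius
    } else {
      val q = s.asInstanceOf[Square]
      q.side * q.side
    }
  }

  // After: a single match with typed patterns; the compiler binds the narrowed value.
  def areaAfter(s: Shape): Double = s match {
    case c: Circle => math.Pi * c.radius * c.radius
    case q: Square => q.side * q.side
  }

  // Side-effecting variant used in several files above: replace
  // `if (x.isInstanceOf[T]) { ... }` with a match whose default case does nothing.
  def describe(s: Shape): Unit = s match {
    case c: Circle => println(s"circle of radius ${c.radius}")
    case _ => // do nothing
  }

  def main(args: Array[String]): Unit = {
    assert(areaBefore(Circle(1.0)) == areaAfter(Circle(1.0)))
    assert(areaBefore(Square(2.0)) == areaAfter(Square(2.0)))
    describe(Circle(3.0))
  }
}

One caveat the refactoring carries over unchanged: patterns over generic types (for example
`case types: IndexedSeq[BaseType]` in the TreeNode change) are unchecked at runtime because of
type erasure, exactly as the original `isInstanceOf[IndexedSeq[BaseType]]` was.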