This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 631e8eb957d [SPARK-42892][SQL] Move sameType and relevant methods out of DataType
631e8eb957d is described below
commit 631e8eb957de1d86c184182d7dd363c01f15af25
Author: Rui Wang <[email protected]>
AuthorDate: Wed Mar 22 17:03:49 2023 -0400
[SPARK-42892][SQL] Move sameType and relevant methods out of DataType
### What changes were proposed in this pull request?
This PR moves the following methods from `DataType`:
1. equalsIgnoreNullability
2. sameType
3. equalsIgnoreCaseAndNullability
The moved methods are put together into a Util class.
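For illustration, a minimal sketch of the resulting call-site pattern (variable names here are only an example; the helper lives in `org.apache.spark.sql.catalyst.types.DataTypeUtils`, as shown in the diff below):
```scala
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.types.{ArrayType, StringType}

// Before this change, call sites used the instance method on DataType:
//   inputType.sameType(ArrayType(StringType))
// After this change, they call the static helper with the same semantics:
val inputType = ArrayType(StringType)
DataTypeUtils.sameType(inputType, ArrayType(StringType))  // ignores nullability; true here
```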
### Why are the changes needed?
To make `DataType` a simpler interface, non-public methods can be moved out of the `DataType` class.
### Does this PR introduce _any_ user-facing change?
No, as the moved methods are private within Spark.
### How was this patch tested?
Existing UT.
Closes #40512 from amaliujia/catalyst_refactor_1.
Authored-by: Rui Wang <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
---
.../apache/spark/sql/kafka010/KafkaWriter.scala | 3 +-
.../scala/org/apache/spark/ml/feature/NGram.scala | 3 +-
.../apache/spark/ml/feature/StopWordsRemover.scala | 3 +-
.../spark/ml/source/libsvm/LibSVMRelation.scala | 5 +-
.../apache/spark/mllib/util/modelSaveLoad.scala | 3 +-
.../org/apache/spark/ml/feature/LSHTest.scala | 11 ++--
.../catalyst/analysis/ResolveInlineTables.scala | 3 +-
.../catalyst/analysis/TableOutputResolver.scala | 3 +-
.../spark/sql/catalyst/analysis/TypeCoercion.scala | 8 ++-
.../spark/sql/catalyst/catalog/interface.scala | 5 +-
.../apache/spark/sql/catalyst/dsl/package.scala | 3 +-
.../sql/catalyst/expressions/Expression.scala | 3 +-
.../sql/catalyst/expressions/SchemaPruning.scala | 3 +-
.../expressions/collectionOperations.scala | 23 +++----
.../spark/sql/catalyst/expressions/hash.scala | 3 +-
.../expressions/higherOrderFunctions.scala | 3 +-
.../spark/sql/catalyst/optimizer/Optimizer.scala | 4 +-
.../optimizer/UnwrapCastInBinaryComparison.scala | 3 +-
.../spark/sql/catalyst/optimizer/objects.scala | 6 +-
.../sql/catalyst/plans/logical/LogicalPlan.scala | 5 +-
.../plans/logical/basicLogicalOperators.scala | 9 ++-
.../spark/sql/catalyst/types/DataTypeUtils.scala | 77 ++++++++++++++++++++++
.../sql/internal/connector/PredicateUtils.scala | 6 +-
.../org/apache/spark/sql/types/DataType.scala | 60 +----------------
.../org/apache/spark/sql/types/StructType.scala | 4 +-
.../catalyst/analysis/AnsiTypeCoercionSuite.scala | 3 +-
.../sql/catalyst/analysis/TypeCoercionSuite.scala | 3 +-
.../catalyst/expressions/SelectedFieldSuite.scala | 3 +-
.../optimizer/ObjectSerializerPruningSuite.scala | 6 +-
.../org/apache/spark/sql/types/DataTypeSuite.scala | 3 +-
.../apache/spark/sql/types/StructTypeSuite.scala | 23 +++----
.../datasources/parquet/ParquetColumnVector.java | 3 +-
.../datasources/binaryfile/BinaryFileFormat.scala | 3 +-
.../spark/sql/execution/joins/HashJoin.scala | 3 +-
.../ApplyInPandasWithStatePythonRunner.scala | 3 +-
.../streaming/StreamingSymmetricHashJoinExec.scala | 3 +-
.../spark/sql/FileBasedDataSourceSuite.scala | 5 +-
.../connector/GroupBasedDeleteFromTableSuite.scala | 6 +-
.../execution/datasources/SchemaPruningSuite.scala | 3 +-
.../parquet/ParquetPartitionDiscoverySuite.scala | 3 +-
.../spark/sql/hive/HiveExternalCatalog.scala | 3 +-
.../spark/sql/hive/HiveMetastoreCatalog.scala | 3 +-
.../spark/sql/hive/MetastoreDataSourcesSuite.scala | 7 +-
.../sql/hive/execution/HiveSQLViewSuite.scala | 3 +-
44 files changed, 210 insertions(+), 138 deletions(-)
diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala
index 5ef4b3a1c19..92c51416f48 100644
--- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala
+++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala
@@ -23,6 +23,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType, StringType}
import org.apache.spark.util.Utils
@@ -115,7 +116,7 @@ private[kafka010] object KafkaWriter extends Logging {
desired: Seq[DataType])(
default: => Expression): Expression = {
val expr = schema.find(_.name == attrName).getOrElse(default)
- if (!desired.exists(_.sameType(expr.dataType))) {
+ if (!desired.exists(e => DataTypeUtils.sameType(e, expr.dataType))) {
throw new IllegalStateException(s"$attrName attribute unsupported type " +
s"${expr.dataType.catalogString}. $attrName must be a(n) " +
s"${desired.map(_.catalogString).mkString(" or ")}")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala
index fd6fde0744d..d72fb6ecc76 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala
@@ -21,6 +21,7 @@ import org.apache.spark.annotation.Since
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param._
import org.apache.spark.ml.util._
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
/**
@@ -64,7 +65,7 @@ class NGram @Since("1.5.0") (@Since("1.5.0") override val uid: String)
}
override protected def validateInputType(inputType: DataType): Unit = {
- require(inputType.sameType(ArrayType(StringType)),
+ require(DataTypeUtils.sameType(inputType, ArrayType(StringType)),
s"Input type must be ${ArrayType(StringType).catalogString} but got " +
inputType.catalogString)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
index 8bcd7909b60..056baa1b6bf 100755
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
@@ -25,6 +25,7 @@ import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasInputCols, HasOutputCol, HasOutputCols}
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}
@@ -191,7 +192,7 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String
require(!schema.fieldNames.contains(outputColName),
s"Output Column $outputColName already exists.")
val inputType = schema(inputColName).dataType
- require(inputType.sameType(ArrayType(StringType)), "Input type must be " +
+ require(DataTypeUtils.sameType(inputType, ArrayType(StringType)), "Input type must be " +
s"${ArrayType(StringType).catalogString} but got ${inputType.catalogString}.")
StructField(outputColName, inputType, schema(inputColName).nullable)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index 6cd635f9cd9..f4c5e3eece2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
@@ -80,8 +81,8 @@ private[libsvm] class LibSVMFileFormat
private def verifySchema(dataSchema: StructType, forWriting: Boolean): Unit = {
if (
dataSchema.size != 2 ||
- !dataSchema(0).dataType.sameType(DataTypes.DoubleType) ||
- !dataSchema(1).dataType.sameType(new VectorUDT()) ||
+ !DataTypeUtils.sameType(dataSchema(0).dataType, DataTypes.DoubleType) ||
+ !DataTypeUtils.sameType(dataSchema(1).dataType, new VectorUDT()) ||
!(forWriting || dataSchema(1).metadata.getLong(LibSVMOptions.NUM_FEATURES).toInt > 0)
) {
throw new IOException(s"Illegal schema for libsvm data, schema=$dataSchema")
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala
index 8f2d8b9014c..c13bc4099ce 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala
@@ -26,6 +26,7 @@ import org.json4s.jackson.JsonMethods._
import org.apache.spark.SparkContext
import org.apache.spark.annotation.Since
import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.types.{DataType, StructField, StructType}
/**
@@ -105,7 +106,7 @@ private[mllib] object Loader {
assert(loadedFields.contains(field.name), s"Unable to parse model data." +
s" Expected field with name ${field.name} was missing in loaded schema:" +
s" ${loadedFields.mkString(", ")}")
- assert(loadedFields(field.name).sameType(field.dataType),
+ assert(DataTypeUtils.sameType(loadedFields(field.name), field.dataType),
s"Unable to parse model data. Expected field $field but found field" +
s" with different type: ${loadedFields(field.name)}")
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala
index 2815adb75ad..e5e891d5b8f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala
@@ -22,6 +22,7 @@ import org.scalatest.Assertions._
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.util.{MLTestingUtils, SchemaUtils}
import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DataTypes
@@ -116,7 +117,7 @@ private[ml] object LSHTest {
// Compute actual
val actual = model.approxNearestNeighbors(dataset, key, k, singleProbe, "distCol")
- assert(actual.schema.sameType(model
+ assert(DataTypeUtils.sameType(actual.schema, model
.transformSchema(dataset.schema)
.add("distCol", DataTypes.DoubleType))
)
@@ -156,10 +157,10 @@ private[ml] object LSHTest {
val actual = model.approxSimilarityJoin(datasetA, datasetB, threshold)
SchemaUtils.checkColumnType(actual.schema, "distCol", DataTypes.DoubleType)
- assert(actual.schema.apply("datasetA").dataType
- .sameType(model.transformSchema(datasetA.schema)))
- assert(actual.schema.apply("datasetB").dataType
- .sameType(model.transformSchema(datasetB.schema)))
+ assert(DataTypeUtils.sameType(actual.schema.apply("datasetA").dataType,
+ model.transformSchema(datasetA.schema)))
+ assert(DataTypeUtils.sameType(actual.schema.apply("datasetB").dataType,
+ model.transformSchema(datasetB.schema)))
// Compute precision and recall
val correctCount = actual.filter(col("distCol") < threshold).count().toDouble
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala
index 48c0b83d240..3952ef71b64 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.AliasHelper
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.AlwaysProcess
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.types.{StructField, StructType}
/**
@@ -106,7 +107,7 @@ object ResolveInlineTables extends Rule[LogicalPlan] with CastSupport with Alias
InternalRow.fromSeq(row.zipWithIndex.map { case (e, ci) =>
val targetType = fields(ci).dataType
try {
- val castedExpr = if (e.dataType.sameType(targetType)) {
+ val castedExpr = if (DataTypeUtils.sameType(e.dataType, targetType)) {
e
} else {
cast(e, targetType)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
index e1ee0defa23..71f4eb2918c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
@@ -21,6 +21,7 @@ import scala.collection.mutable
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.errors.QueryCompilationErrors
@@ -266,7 +267,7 @@ object TableOutputResolver {
tableAttr.dataType
}
val storeAssignmentPolicy = conf.storeAssignmentPolicy
- lazy val outputField = if (tableAttr.dataType.sameType(queryExpr.dataType) &&
+ lazy val outputField = if (DataTypeUtils.sameType(tableAttr.dataType, queryExpr.dataType) &&
tableAttr.name == queryExpr.name &&
tableAttr.metadata == queryExpr.metadata) {
Some(queryExpr)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index e57d1075d2f..059c36c4f90 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -154,12 +155,12 @@ abstract class TypeCoercionBase {
true
} else {
val head = types.head
- types.tail.forall(_.sameType(head))
+ types.tail.forall(e => DataTypeUtils.sameType(e, head))
}
}
protected def castIfNotSameType(expr: Expression, dt: DataType): Expression = {
- if (!expr.dataType.sameType(dt)) {
+ if (!DataTypeUtils.sameType(expr.dataType, dt)) {
Cast(expr, dt)
} else {
expr
@@ -622,7 +623,8 @@ abstract class TypeCoercionBase {
override val transform: PartialFunction[Expression, Expression] = {
// Lambda function isn't resolved when the rule is executed.
case m @ MapZipWith(left, right, function) if m.arguments.forall(a => a.resolved &&
- MapType.acceptsType(a.dataType)) && !m.leftKeyType.sameType(m.rightKeyType) =>
+ MapType.acceptsType(a.dataType)) &&
+ !DataTypeUtils.sameType(m.leftKeyType, m.rightKeyType) =>
findWiderTypeForTwo(m.leftKeyType, m.rightKeyType) match {
case Some(finalKeyType) if !Cast.forceNullable(m.leftKeyType, finalKeyType) &&
!Cast.forceNullable(m.rightKeyType, finalKeyType) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
index 08dd2dfd5bc..8bbf1f4f564 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
@@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTable.VIEW_STORING_ANALYZED_
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Cast, ExprId, Literal}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.errors.QueryCompilationErrors
@@ -824,8 +825,8 @@ case class HiveTableRelation(
@transient prunedPartitions: Option[Seq[CatalogTablePartition]] = None)
extends LeafNode with MultiInstanceRelation {
assert(tableMeta.identifier.database.isDefined)
- assert(tableMeta.partitionSchema.sameType(partitionCols.toStructType))
- assert(tableMeta.dataSchema.sameType(dataCols.toStructType))
+ assert(DataTypeUtils.sameType(tableMeta.partitionSchema, partitionCols.toStructType))
+ assert(DataTypeUtils.sameType(tableMeta.dataSchema, dataCols.toStructType))
// The partition column should always appear after data columns.
override def output: Seq[AttributeReference] = dataCols ++ partitionCols
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index c6cc863108d..ac439203cb7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -129,7 +130,7 @@ package object dsl {
UnresolvedExtractValue(expr, Literal(fieldName))
def cast(to: DataType): Expression = {
- if (expr.resolved && expr.dataType.sameType(to)) {
+ if (expr.resolved && DataTypeUtils.sameType(expr.dataType, to)) {
expr
} else {
val cast = Cast(expr, to)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 2d2236a8a80..92bffc2be2e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.{BinaryLike, CurrentOrigin, LeafLike, QuaternaryLike, SQLQueryContext, TernaryLike, TreeNode, UnaryLike}
import org.apache.spark.sql.catalyst.trees.TreePattern.{RUNTIME_REPLACEABLE, TreePattern}
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.errors.{QueryErrorsBase, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
@@ -771,7 +772,7 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes wi
override def checkInputDataTypes(): TypeCheckResult = {
// First check whether left and right have the same type, then check if the type is acceptable.
- if (!left.dataType.sameType(right.dataType)) {
+ if (!DataTypeUtils.sameType(left.dataType, right.dataType)) {
DataTypeMismatch(
errorSubClass = "BINARY_OP_DIFF_TYPES",
messageParameters = Map(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
index e14bcba0ace..dd2d6c2cb61 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala
@@ -22,6 +22,7 @@ import java.util.Locale
import scala.collection.immutable.HashMap
import org.apache.spark.sql.catalyst.SQLConfHelper
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.types._
object SchemaPruning extends SQLConfHelper {
@@ -126,7 +127,7 @@ object SchemaPruning extends SQLConfHelper {
val rootFieldType = StructType(Array(root.field))
val optFieldType = StructType(Array(opt.field))
val merged = optFieldType.merge(rootFieldType)
- merged.sameType(optFieldType)
+ DataTypeUtils.sameType(merged, optFieldType)
}
}
} ++ rootFields
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 2ccb3a6d0cd..adeccb3ec7e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.{BinaryLike, SQLQueryContext, UnaryLike}
import org.apache.spark.sql.catalyst.trees.TreePattern.{ARRAYS_ZIP, CONCAT, TreePattern}
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
@@ -67,7 +68,7 @@ trait BinaryArrayExpressionWithImplicitCast
override def checkInputDataTypes(): TypeCheckResult = {
(left.dataType, right.dataType) match {
- case (ArrayType(e1, _), ArrayType(e2, _)) if e1.sameType(e2) =>
+ case (ArrayType(e1, _), ArrayType(e2, _)) if DataTypeUtils.sameType(e1, e2) =>
TypeCheckResult.TypeCheckSuccess
case _ =>
DataTypeMismatch(
@@ -245,7 +246,7 @@ case class MapContainsKey(left: Expression, right: Expression)
DataTypeMismatch(
errorSubClass = "NULL_TYPE",
Map("functionName" -> toSQLId(prettyName)))
- case (MapType(kt, _, _), dt) if kt.sameType(dt) =>
+ case (MapType(kt, _, _), dt) if DataTypeUtils.sameType(kt, dt) =>
TypeUtils.checkForOrderingExpr(kt, prettyName)
case _ =>
DataTypeMismatch(
@@ -1327,7 +1328,7 @@ case class ArrayContains(left: Expression, right: Expression)
"inputSql" -> toSQLExpr(left),
"inputType" -> toSQLType(left.dataType))
)
- case (ArrayType(e1, _), e2) if e1.sameType(e2) =>
+ case (ArrayType(e1, _), e2) if DataTypeUtils.sameType(e1, e2) =>
TypeUtils.checkForOrderingExpr(e2, prettyName)
case _ =>
DataTypeMismatch(
@@ -1512,7 +1513,7 @@ case class ArrayPrepend(left: Expression, right: Expression)
override def checkInputDataTypes(): TypeCheckResult = {
(left.dataType, right.dataType) match {
- case (ArrayType(e1, _), e2) if e1.sameType(e2) => TypeCheckResult.TypeCheckSuccess
+ case (ArrayType(e1, _), e2) if DataTypeUtils.sameType(e1, e2) => TypeCheckResult.TypeCheckSuccess
case (ArrayType(e1, _), e2) => DataTypeMismatch(
errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES",
messageParameters = Map(
@@ -2266,7 +2267,7 @@ case class ArrayPosition(left: Expression, right: Expression)
"inputSql" -> toSQLExpr(left),
"inputType" -> toSQLType(left.dataType))
)
- case (ArrayType(e1, _), e2) if e1.sameType(e2) =>
+ case (ArrayType(e1, _), e2) if DataTypeUtils.sameType(e1, e2) =>
TypeUtils.checkForOrderingExpr(e2, prettyName)
case _ =>
DataTypeMismatch(
@@ -2426,7 +2427,7 @@ case class ElementAt(
"inputSql" -> toSQLExpr(right),
"inputType" -> toSQLType(right.dataType))
)
- case (MapType(e1, _, _), e2) if (!e2.sameType(e1)) =>
+ case (MapType(e1, _, _), e2) if (!DataTypeUtils.sameType(e2, e1)) =>
DataTypeMismatch(
errorSubClass = "MAP_FUNCTION_DIFF_TYPES",
messageParameters = Map(
@@ -3026,7 +3027,7 @@ case class Sequence(
val startType = start.dataType
def stepType = stepOpt.get.dataType
val typesCorrect =
- startType.sameType(stop.dataType) &&
+ DataTypeUtils.sameType(startType, stop.dataType) &&
(startType match {
case TimestampType | TimestampNTZType =>
stepOpt.isEmpty || CalendarIntervalType.acceptsType(stepType) ||
@@ -3037,7 +3038,7 @@ case class Sequence(
YearMonthIntervalType.acceptsType(stepType) ||
DayTimeIntervalType.acceptsType(stepType)
case _: IntegralType =>
- stepOpt.isEmpty || stepType.sameType(startType)
+ stepOpt.isEmpty || DataTypeUtils.sameType(stepType, startType)
case _ => false
})
@@ -3715,7 +3716,7 @@ case class ArrayRemove(left: Expression, right: Expression)
override def checkInputDataTypes(): TypeCheckResult = {
(left.dataType, right.dataType) match {
- case (ArrayType(e1, _), e2) if e1.sameType(e2) =>
+ case (ArrayType(e1, _), e2) if DataTypeUtils.sameType(e1, e2) =>
TypeUtils.checkForOrderingExpr(e2, prettyName)
case _ =>
DataTypeMismatch(
@@ -4791,7 +4792,7 @@ case class ArrayInsert(srcArrayExpr: Expression, posExpr: Expression, itemExpr:
"inputSql" -> toSQLExpr(second),
"inputType" -> toSQLType(second.dataType))
)
- case (ArrayType(e1, _), e2, e3) if e1.sameType(e3) =>
+ case (ArrayType(e1, _), e2, e3) if DataTypeUtils.sameType(e1, e3) =>
TypeCheckResult.TypeCheckSuccess
case _ =>
DataTypeMismatch(
@@ -5078,7 +5079,7 @@ case class ArrayAppend(left: Expression, right: Expression)
override def checkInputDataTypes(): TypeCheckResult = {
(left.dataType, right.dataType) match {
- case (ArrayType(e1, _), e2) if e1.sameType(e2) => TypeCheckResult.TypeCheckSuccess
+ case (ArrayType(e1, _), e2) if DataTypeUtils.sameType(e1, e2) => TypeCheckResult.TypeCheckSuccess
case (ArrayType(e1, _), e2) => DataTypeMismatch(
errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES",
messageParameters = Map(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index 8ac879c73ae..3e1248482cf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.Cast._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.errors.QueryCompilationErrors
@@ -308,7 +309,7 @@ abstract class HashExpression[E] extends Expression {
}
val hashResultType = CodeGenerator.javaType(dataType)
- val typedSeed = if (dataType.sameType(LongType)) s"${seed}L" else s"$seed"
+ val typedSeed = if (DataTypeUtils.sameType(dataType, LongType)) s"${seed}L" else s"$seed"
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = childrenHash,
funcName = "computeHash",
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index c2db38bae45..fec1df108bc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.trees.{BinaryLike, QuaternaryLike, TernaryLike}
import org.apache.spark.sql.catalyst.trees.TreePattern._
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
@@ -1059,7 +1060,7 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression)
override def checkArgumentDataTypes(): TypeCheckResult = {
super.checkArgumentDataTypes() match {
case TypeCheckResult.TypeCheckSuccess =>
- if (leftKeyType.sameType(rightKeyType)) {
+ if (DataTypeUtils.sameType(leftKeyType, rightKeyType)) {
TypeUtils.checkForOrderingExpr(leftKeyType, prettyName)
} else {
DataTypeMismatch(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 3e8571f3eb6..ffc7284e40a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{RepartitionOperation, _}
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.AlwaysProcess
import org.apache.spark.sql.catalyst.trees.TreePattern._
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.connector.catalog.CatalogManager
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
@@ -1621,7 +1622,8 @@ object EliminateSorts extends Rule[LogicalPlan] {
// Arithmetic operations for floating-point values are order-sensitive
// (they are not associative).
case _: Sum | _: Average | _: CentralMomentAgg =>
- !Seq(FloatType, DoubleType).exists(_.sameType(func.children.head.dataType))
+ !Seq(FloatType, DoubleType)
+ .exists(e => DataTypeUtils.sameType(e, func.children.head.dataType))
case _ => false
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
index 4c41fb0fa31..50c5bdc7b8d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.{BINARY_COMPARISON, IN, INSET}
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.types._
/**
@@ -380,7 +381,7 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] {
fromExp: Expression,
toType: DataType,
literalType: DataType): Boolean = {
- toType.sameType(literalType) &&
+ DataTypeUtils.sameType(toType, literalType) &&
!fromExp.foldable &&
toType.isInstanceOf[NumericType] &&
canUnwrapCast(fromExp.dataType, toType)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
index 3387bb20077..80ec5e40bec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.TreePattern._
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType, UserDefinedType}
/*
@@ -165,7 +166,8 @@ object ObjectSerializerPruning extends Rule[LogicalPlan] {
* Note: we should do `transformUp` explicitly to change data types.
*/
private def alignNullTypeInIf(expr: Expression) = expr.transformUp {
- case i @ If(IsNullCondition(), Literal(null, dt), ser) if !dt.sameType(ser.dataType) =>
+ case i @ If(IsNullCondition(), Literal(null, dt), ser)
+ if !DataTypeUtils.sameType(dt, ser.dataType) =>
i.copy(trueValue = Literal(null, ser.dataType))
}
@@ -204,7 +206,7 @@ object ObjectSerializerPruning extends Rule[LogicalPlan] {
val transformedSerializer = serializer.transformDown(transformer)
val prunedSerializer =
alignNullTypeInIf(transformedSerializer).asInstanceOf[NamedExpression]
- if (prunedSerializer.dataType.sameType(prunedDataType)) {
+ if (DataTypeUtils.sameType(prunedSerializer.dataType, prunedDataType)) {
prunedSerializer
} else {
serializer
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 36187bb2d55..e1d46ff73f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -23,9 +23,10 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{AliasAwareQueryOutputOrdering, QueryPlan}
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.LogicalPlanStats
import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, TreeNodeTag, UnaryLike}
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.MetadataColumnHelper
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types.StructType
abstract class LogicalPlan
@@ -314,7 +315,7 @@ object LogicalPlanIntegrity {
Some("Special expressions are placed in the wrong plan: " +
currentPlan.treeString)
} else {
LogicalPlanIntegrity.validateExprIdUniqueness(currentPlan).orElse {
- if (!DataType.equalsIgnoreNullability(previousPlan.schema, currentPlan.schema)) {
+ if (!DataTypeUtils.equalsIgnoreNullability(previousPlan.schema, currentPlan.schema)) {
Some(s"The plan output schema has changed from ${previousPlan.schema.sql} to " +
currentPlan.schema.sql + s". The previous plan: ${previousPlan.treeString}\nThe new " +
"plan:\n" + currentPlan.treeString)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index cdf48dd265f..1eddd7ed24d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition}
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.catalyst.trees.TreePattern._
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
@@ -330,7 +331,7 @@ abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends Binar
childrenResolved &&
left.output.length == right.output.length &&
left.output.zip(right.output).forall { case (l, r) =>
- l.dataType.sameType(r.dataType)
+ DataTypeUtils.sameType(l.dataType, r.dataType)
} && duplicateResolved
}
@@ -1280,7 +1281,9 @@ object Expand {
} :+ {
val bitMask = buildBitmask(groupingSetAttrs, attrMap)
val dataType = GroupingID.dataType
- Literal.create(if (dataType.sameType(IntegerType)) bitMask.toInt else bitMask, dataType)
+ Literal.create(
+ if (DataTypeUtils.sameType(dataType, IntegerType)) bitMask.toInt
+ else bitMask, dataType)
}
if (hasDuplicateGroupingSets) {
@@ -1502,7 +1505,7 @@ case class Unpivot(
def valuesTypeCoercioned: Boolean = canBeCoercioned &&
// all inner values at position idx must have the same data type
values.get.head.zipWithIndex.forall { case (v, idx) =>
- values.get.tail.forall(vals => vals(idx).dataType.sameType(v.dataType))
+ values.get.tail.forall(vals => DataTypeUtils.sameType(vals(idx).dataType, v.dataType))
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala
new file mode 100644
index 00000000000..c27a17b6dd9
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.types
+
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
+
+object DataTypeUtils {
+ /**
+ * Check if `this` and `other` are the same data type when ignoring nullability
+ * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`).
+ */
+ def sameType(left: DataType, right: DataType): Boolean =
+ if (SQLConf.get.caseSensitiveAnalysis) {
+ equalsIgnoreNullability(left, right)
+ } else {
+ equalsIgnoreCaseAndNullability(left, right)
+ }
+
+ /**
+ * Compares two types, ignoring nullability of ArrayType, MapType, StructType.
+ */
+ def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = {
+ (left, right) match {
+ case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) =>
+ equalsIgnoreNullability(leftElementType, rightElementType)
+ case (MapType(leftKeyType, leftValueType, _), MapType(rightKeyType, rightValueType, _)) =>
+ equalsIgnoreNullability(leftKeyType, rightKeyType) &&
+ equalsIgnoreNullability(leftValueType, rightValueType)
+ case (StructType(leftFields), StructType(rightFields)) =>
+ leftFields.length == rightFields.length &&
+ leftFields.zip(rightFields).forall { case (l, r) =>
+ l.name == r.name && equalsIgnoreNullability(l.dataType, r.dataType)
+ }
+ case (l, r) => l == r
+ }
+ }
+
+ /**
+ * Compares two types, ignoring nullability of ArrayType, MapType, StructType, and ignoring case
+ * sensitivity of field names in StructType.
+ */
+ def equalsIgnoreCaseAndNullability(from: DataType, to: DataType): Boolean = {
+ (from, to) match {
+ case (ArrayType(fromElement, _), ArrayType(toElement, _)) =>
+ equalsIgnoreCaseAndNullability(fromElement, toElement)
+
+ case (MapType(fromKey, fromValue, _), MapType(toKey, toValue, _)) =>
+ equalsIgnoreCaseAndNullability(fromKey, toKey) &&
+ equalsIgnoreCaseAndNullability(fromValue, toValue)
+
+ case (StructType(fromFields), StructType(toFields)) =>
+ fromFields.length == toFields.length &&
+ fromFields.zip(toFields).forall { case (l, r) =>
+ l.name.equalsIgnoreCase(r.name) &&
+ equalsIgnoreCaseAndNullability(l.dataType, r.dataType)
+ }
+
+ case (fromDataType, toDataType) => fromDataType == toDataType
+ }
+ }
+}
+
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala
index a08223e2159..339febb374e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.internal.connector
import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.connector.expressions.{LiteralValue, NamedReference}
import org.apache.spark.sql.connector.expressions.filter.{And => V2And, Not => V2Not, Or => V2Or, Predicate}
import org.apache.spark.sql.sources.{AlwaysFalse, AlwaysTrue, And, EqualNullSafe, EqualTo, Filter, GreaterThan, GreaterThanOrEqual, In, IsNotNull, IsNull, LessThan, LessThanOrEqual, Not, Or, StringContains, StringEndsWith, StringStartsWith}
@@ -44,7 +45,8 @@ private[sql] object PredicateUtils {
if (values.length > 0) {
if (!values.forall(_.isInstanceOf[LiteralValue[_]])) return None
val dataType = values(0).asInstanceOf[LiteralValue[_]].dataType
- if (!values.forall(_.asInstanceOf[LiteralValue[_]].dataType.sameType(dataType))) {
+ if (!values.forall(e =>
+ DataTypeUtils.sameType(e.asInstanceOf[LiteralValue[_]].dataType, dataType))) {
return None
}
val inValues = values.map(v =>
@@ -80,7 +82,7 @@ private[sql] object PredicateUtils {
case "STARTS_WITH" | "ENDS_WITH" | "CONTAINS" if isValidBinaryPredicate() =>
val attribute = predicate.children()(0).toString
val value = predicate.children()(1).asInstanceOf[LiteralValue[_]]
- if (!value.dataType.sameType(StringType)) return None
+ if (!DataTypeUtils.sameType(value.dataType, StringType)) return None
val v1Value = value.value.toString
val v1Filter = predicate.name() match {
case "STARTS_WITH" =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index 13a7b03bc61..5e2974c2645 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -36,7 +36,6 @@ import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.catalyst.util.DataTypeJsonUtils.{DataTypeJsonDeserializer, DataTypeJsonSerializer}
import org.apache.spark.sql.catalyst.util.StringUtils.StringConcat
import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy
import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy.{ANSI, STRICT}
import org.apache.spark.sql.types.DayTimeIntervalType._
@@ -93,17 +92,6 @@ abstract class DataType extends AbstractDataType {
def sql: String = simpleString.toUpperCase(Locale.ROOT)
- /**
- * Check if `this` and `other` are the same data type when ignoring nullability
- * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`).
- */
- private[spark] def sameType(other: DataType): Boolean =
- if (SQLConf.get.caseSensitiveAnalysis) {
- DataType.equalsIgnoreNullability(this, other)
- } else {
- DataType.equalsIgnoreCaseAndNullability(this, other)
- }
-
/**
* Returns the same data type but set all nullability fields are true
* (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`).
@@ -117,7 +105,8 @@ abstract class DataType extends AbstractDataType {
override private[sql] def defaultConcreteType: DataType = this
- override private[sql] def acceptsType(other: DataType): Boolean = sameType(other)
+ override private[sql] def acceptsType(other: DataType): Boolean =
+ DataTypeUtils.sameType(this, other)
private[sql] def physicalDataType: PhysicalDataType = UninitializedPhysicalType
}
@@ -294,25 +283,6 @@ object DataType {
}
}
- /**
- * Compares two types, ignoring nullability of ArrayType, MapType, StructType.
- */
- private[sql] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = {
- (left, right) match {
- case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) =>
- equalsIgnoreNullability(leftElementType, rightElementType)
- case (MapType(leftKeyType, leftValueType, _), MapType(rightKeyType, rightValueType, _)) =>
- equalsIgnoreNullability(leftKeyType, rightKeyType) &&
- equalsIgnoreNullability(leftValueType, rightValueType)
- case (StructType(leftFields), StructType(rightFields)) =>
- leftFields.length == rightFields.length &&
- leftFields.zip(rightFields).forall { case (l, r) =>
- l.name == r.name && equalsIgnoreNullability(l.dataType, r.dataType)
- }
- case (l, r) => l == r
- }
- }
-
/**
* Compares two types, ignoring compatible nullability of ArrayType, MapType, StructType.
*
@@ -377,30 +347,6 @@ object DataType {
}
}
- /**
- * Compares two types, ignoring nullability of ArrayType, MapType, StructType, and ignoring case
- * sensitivity of field names in StructType.
- */
- private[sql] def equalsIgnoreCaseAndNullability(from: DataType, to: DataType): Boolean = {
- (from, to) match {
- case (ArrayType(fromElement, _), ArrayType(toElement, _)) =>
- equalsIgnoreCaseAndNullability(fromElement, toElement)
-
- case (MapType(fromKey, fromValue, _), MapType(toKey, toValue, _)) =>
- equalsIgnoreCaseAndNullability(fromKey, toKey) &&
- equalsIgnoreCaseAndNullability(fromValue, toValue)
-
- case (StructType(fromFields), StructType(toFields)) =>
- fromFields.length == toFields.length &&
- fromFields.zip(toFields).forall { case (l, r) =>
- l.name.equalsIgnoreCase(r.name) &&
- equalsIgnoreCaseAndNullability(l.dataType, r.dataType)
- }
-
- case (fromDataType, toDataType) => fromDataType == toDataType
- }
- }
-
/**
* Returns true if the two data types share the same "shape", i.e. the types
* are the same, but the field names don't need to be the same.
@@ -584,7 +530,7 @@ object DataType {
true
}
- case (w, r) if w.sameType(r) && !w.isInstanceOf[NullType] =>
+ case (w, r) if DataTypeUtils.sameType(w, r) && !w.isInstanceOf[NullType] =>
true
case (w, r) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index 9ef3c4d60fd..323f7c6df32 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering}
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, LegacyTypeStringParser}
import org.apache.spark.sql.catalyst.trees.Origin
-import org.apache.spark.sql.catalyst.types.{PhysicalDataType, PhysicalStructType}
+import org.apache.spark.sql.catalyst.types.{DataTypeUtils, PhysicalDataType, PhysicalStructType}
import org.apache.spark.sql.catalyst.util.{truncatedString, StringUtils}
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._
import org.apache.spark.sql.catalyst.util.StringUtils.StringConcat
@@ -704,7 +704,7 @@ object StructType extends AbstractDataType {
// Found a missing field in `source`.
newFields += field
} else if (bothStructType(found.get.dataType, field.dataType) &&
- !found.get.dataType.sameType(field.dataType)) {
+ !DataTypeUtils.sameType(found.get.dataType, field.dataType)) {
// Found a field with same name, but different data type.
findMissingFields(found.get.dataType.asInstanceOf[StructType],
field.dataType.asInstanceOf[StructType], resolver).map { missingType =>
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala
index afbc2fdb5a0..cc0f6046de0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -65,7 +66,7 @@ class AnsiTypeCoercionSuite extends TypeCoercionSuiteBase {
private def shouldCastStringLiteral(to: AbstractDataType, expected: DataType): Unit = {
val input = Literal("123")
val castResult = AnsiTypeCoercion.implicitCast(input, to)
- assert(DataType.equalsIgnoreCaseAndNullability(
+ assert(DataTypeUtils.equalsIgnoreCaseAndNullability(
castResult.map(_.dataType).orNull, expected),
s"Failed to cast String literal to $to")
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index e30cce23136..1e8286dec79 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.ReferenceAllColumns
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -53,7 +54,7 @@ abstract class TypeCoercionSuiteBase extends AnalysisTest {
// Check null value
val castNull = implicitCast(createNull(from), to)
- assert(DataType.equalsIgnoreCaseAndNullability(
+ assert(DataTypeUtils.equalsIgnoreCaseAndNullability(
castNull.map(_.dataType).orNull, expected),
s"Failed to cast $from to $to")
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SelectedFieldSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SelectedFieldSuite.scala
index 3724f313ca6..e09ae776d1c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SelectedFieldSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SelectedFieldSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.AnalysisTest
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.types._
class SelectedFieldSuite extends AnalysisTest {
@@ -534,7 +535,7 @@ class SelectedFieldSuite extends AnalysisTest {
indent("but it actually selected\n") +
indent(StructType(actual :: Nil).treeString) +
indent("Note that expected.dataType.sameType(actual.dataType) = " +
- expected.dataType.sameType(actual.dataType)))
+ DataTypeUtils.sameType(expected.dataType, actual.dataType)))
throw ex
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ObjectSerializerPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ObjectSerializerPruningSuite.scala
index dfe190e6ddc..3dd58dc9fc1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ObjectSerializerPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ObjectSerializerPruningSuite.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -96,7 +97,8 @@ class ObjectSerializerPruningSuite extends PlanTest {
CreateNamedStruct(children.take(2))
}.transformUp {
// Aligns null literal in `If` expression to make it resolvable.
- case i @ If(_: IsNull, Literal(null, dt), ser) if !dt.sameType(ser.dataType) =>
+ case i @ If(_: IsNull, Literal(null, dt), ser)
+ if !DataTypeUtils.sameType(dt, ser.dataType) =>
i.copy(trueValue = Literal(null, ser.dataType))
}.asInstanceOf[NamedExpression]
@@ -126,7 +128,7 @@ class ObjectSerializerPruningSuite extends PlanTest {
}.transformUp {
// Aligns null literal in `If` expression to make it resolvable.
case i @ If(invoke: Invoke, Literal(null, dt), ser) if invoke.functionName == "isNullAt" &&
- !dt.sameType(ser.dataType) =>
+ !DataTypeUtils.sameType(dt, ser.dataType) =>
i.copy(trueValue = Literal(null, ser.dataType))
}.asInstanceOf[NamedExpression]
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
index 2a487133b48..4001b546566 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
@@ -22,6 +22,7 @@ import com.fasterxml.jackson.core.JsonParseException
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, caseSensitiveResolution}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.StringUtils.StringConcat
import org.apache.spark.sql.types.DataTypeTestUtils.{dayTimeIntervalTypes, yearMonthIntervalTypes}
@@ -207,7 +208,7 @@ class DataTypeSuite extends SparkFunSuite {
test(s"from DDL - $dataType") {
val parsed = StructType.fromDDL(s"a ${dataType.sql}")
val expected = new StructType().add("a", dataType)
- assert(parsed.sameType(expected))
+ assert(DataTypeUtils.sameType(parsed, expected))
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala
index d9eb0892d13..e252594540b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, caseSensitiveResolution}
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.catalyst.plans.SQLHelper
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DayTimeIntervalType => DT}
@@ -166,22 +167,22 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper {
val source1 = StructType.fromDDL("c1 INT")
val missing1 = StructType.fromDDL("c2 STRUCT<c3: INT, c4: STRUCT<c5: INT, c6: INT>>")
assert(StructType.findMissingFields(source1, schema, resolver)
- .exists(_.sameType(missing1)))
+ .exists(e => DataTypeUtils.sameType(e, missing1)))
val source2 = StructType.fromDDL("c1 INT, c3 STRING")
val missing2 = StructType.fromDDL("c2 STRUCT<c3: INT, c4: STRUCT<c5: INT, c6: INT>>")
assert(StructType.findMissingFields(source2, schema, resolver)
- .exists(_.sameType(missing2)))
+ .exists(e => DataTypeUtils.sameType(e, missing2)))
val source3 = StructType.fromDDL("c1 INT, c2 STRUCT<c3: INT>")
val missing3 = StructType.fromDDL("c2 STRUCT<c4: STRUCT<c5: INT, c6: INT>>")
assert(StructType.findMissingFields(source3, schema, resolver)
- .exists(_.sameType(missing3)))
+ .exists(e => DataTypeUtils.sameType(e, missing3)))
val source4 = StructType.fromDDL("c1 INT, c2 STRUCT<c3: INT, c4: STRUCT<c6: INT>>")
val missing4 = StructType.fromDDL("c2 STRUCT<c4: STRUCT<c5: INT>>")
assert(StructType.findMissingFields(source4, schema, resolver)
- .exists(_.sameType(missing4)))
+ .exists(e => DataTypeUtils.sameType(e, missing4)))
}
test("find missing (nested) fields: array and map") {
@@ -192,7 +193,7 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper {
val missing5 = StructType.fromDDL("c2 ARRAY<STRUCT<c3: INT, c4: LONG>>")
assert(
StructType.findMissingFields(source5, schemaWithArray, resolver)
- .exists(_.sameType(missing5)))
+ .exists(e => DataTypeUtils.sameType(e, missing5)))
val schemaWithMap1 = StructType.fromDDL(
"c1 INT, c2 MAP<STRUCT<c3: INT, c4: LONG>, STRING>, c3 LONG")
@@ -200,7 +201,7 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper {
val missing6 = StructType.fromDDL("c2 MAP<STRUCT<c3: INT, c4: LONG>,
STRING>")
assert(
StructType.findMissingFields(source6, schemaWithMap1, resolver)
- .exists(_.sameType(missing6)))
+ .exists(e => DataTypeUtils.sameType(e, missing6)))
val schemaWithMap2 = StructType.fromDDL(
"c1 INT, c2 MAP<STRING, STRUCT<c3: INT, c4: LONG>>, c3 STRING")
@@ -208,7 +209,7 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper {
val missing7 = StructType.fromDDL("c2 MAP<STRING, STRUCT<c3: INT, c4:
LONG>>")
assert(
StructType.findMissingFields(source7, schemaWithMap2, resolver)
- .exists(_.sameType(missing7)))
+ .exists(e => DataTypeUtils.sameType(e, missing7)))
// Unsupported: nested struct in array, map
val source8 = StructType.fromDDL("c1 INT, c2 ARRAY<STRUCT<c3: INT>>")
@@ -232,22 +233,22 @@ class StructTypeSuite extends SparkFunSuite with
SQLHelper {
val source1 = StructType.fromDDL("c1 INT, C2 LONG")
val missing1 = StructType.fromDDL("c2 STRUCT<c3: INT, C4: STRUCT<C5:
INT, c6: INT>>")
assert(StructType.findMissingFields(source1, schema, resolver)
- .exists(_.sameType(missing1)))
+ .exists(e => DataTypeUtils.sameType(e, missing1)))
val source2 = StructType.fromDDL("c2 LONG")
val missing2 = StructType.fromDDL("c1 INT")
assert(StructType.findMissingFields(source2, schema, resolver)
- .exists(_.sameType(missing2)))
+ .exists(e => DataTypeUtils.sameType(e, missing2)))
val source3 = StructType.fromDDL("c1 INT, c2 STRUCT<c3: INT, C4:
STRUCT<c5: INT>>")
val missing3 = StructType.fromDDL("c2 STRUCT<C4: STRUCT<C5: INT, c6:
INT>>")
assert(StructType.findMissingFields(source3, schema, resolver)
- .exists(_.sameType(missing3)))
+ .exists(e => DataTypeUtils.sameType(e, missing3)))
val source4 = StructType.fromDDL("c1 INT, c2 STRUCT<c3: INT, C4:
STRUCT<C5: Int>>")
val missing4 = StructType.fromDDL("c2 STRUCT<C4: STRUCT<c6: INT>>")
assert(StructType.findMissingFields(source4, schema, resolver)
- .exists(_.sameType(missing4)))
+ .exists(e => DataTypeUtils.sameType(e, missing4)))
}
}
diff --git
a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnVector.java
b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnVector.java
index 5272151acf2..e8ae6f290bc 100644
---
a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnVector.java
+++
b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnVector.java
@@ -29,6 +29,7 @@ import
org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
import org.apache.spark.sql.types.ArrayType;
import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.catalyst.types.DataTypeUtils;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.MapType;
import org.apache.spark.sql.types.StructType;
@@ -64,7 +65,7 @@ final class ParquetColumnVector {
boolean isTopLevel,
Object defaultValue) {
DataType sparkType = column.sparkType();
- if (!sparkType.sameType(vector.dataType())) {
+ if (!DataTypeUtils.sameType(sparkType, vector.dataType())) {
throw new IllegalArgumentException("Spark type: " + sparkType +
" doesn't match the type: " + vector.dataType() + " in column vector");
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala
index ba6d351761e..cbff526592f 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala
@@ -27,6 +27,7 @@ import org.apache.hadoop.mapreduce.Job
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.datasources.{FileFormat,
OutputWriterFactory, PartitionedFile}
@@ -88,7 +89,7 @@ class BinaryFileFormat extends FileFormat with
DataSourceRegister {
filters: Seq[Filter],
options: Map[String, String],
hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = {
- require(dataSchema.sameType(schema),
+ require(DataTypeUtils.sameType(dataSchema, schema),
s"""
|Binary file data source expects dataSchema: $schema,
|but got: $dataSchema.
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 4595ea049ef..f5e8a36de5a 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight,
BuildSide}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.execution.{CodegenSupport, ExplainUtils,
RowIterator}
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types.{BooleanType, IntegralType, LongType}
@@ -109,7 +110,7 @@ trait HashJoin extends JoinCodegenSupport {
require(leftKeys.length == rightKeys.length &&
leftKeys.map(_.dataType)
.zip(rightKeys.map(_.dataType))
- .forall(types => types._1.sameType(types._2)),
+ .forall(types => DataTypeUtils.sameType(types._1, types._2)),
"Join keys from two sides should have same length and types")
buildSide match {
case BuildLeft => (leftKeys, rightKeys)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
index f3531668c8e..79773a9d534 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.api.python.PythonSQLUtils
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.execution.metric.SQLMetric
import
org.apache.spark.sql.execution.python.ApplyInPandasWithStatePythonRunner.{InType,
OutType, OutTypeForState, STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER}
import
org.apache.spark.sql.execution.python.ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA
@@ -154,7 +155,7 @@ class ApplyInPandasWithStatePythonRunner(
// UDF returns a StructType column in ColumnarBatch, select the
children here
val structVector = batch.column(ordinal).asInstanceOf[ArrowColumnVector]
val dataType = schema(ordinal).dataType.asInstanceOf[StructType]
- assert(dataType.sameType(expectedType),
+ assert(DataTypeUtils.sameType(dataType, expectedType),
s"Schema equality check failure! type from Arrow: $dataType, expected
type: $expectedType")
val outputVectors = dataType.indices.map(structVector.getChild)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
index 2445dcd519e..3ad1dc58cae 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression,
GenericInternalRow, JoinedRow, Literal, Predicate, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetric
import
org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper._
@@ -182,7 +183,7 @@ case class StreamingSymmetricHashJoinExec(
require(leftKeys.length == rightKeys.length &&
leftKeys.map(_.dataType)
.zip(rightKeys.map(_.dataType))
- .forall(types => types._1.sameType(types._2)),
+ .forall(types => DataTypeUtils.sameType(types._1, types._2)),
"Join keys from two sides should have same length and types")
private val storeConf = new StateStoreConf(conf)
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
index 2796b1cf154..a613fe04f5e 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.TestingUDT.{IntervalUDT,
NullData, NullUDT}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference,
GreaterThan, Literal}
import
org.apache.spark.sql.catalyst.expressions.IntegralLiteralTestUtils.{negativeInt,
positiveInt}
import org.apache.spark.sql.catalyst.plans.logical.Filter
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.execution.{FileSourceScanLike, SimpleMode}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.datasources.FilePartition
@@ -88,7 +89,7 @@ class FileBasedDataSourceSuite extends QueryTest
df.write.format(format).option("header", "true").save(dir)
val answerDf = spark.read.format(format).option("header",
"true").load(dir)
- assert(df.schema.sameType(answerDf.schema))
+ assert(DataTypeUtils.sameType(df.schema, answerDf.schema))
checkAnswer(df, answerDf)
}
}
@@ -104,7 +105,7 @@ class FileBasedDataSourceSuite extends QueryTest
emptyDf.write.format(format).save(path)
val df = spark.read.format(format).load(path)
- assert(df.schema.sameType(emptyDf.schema))
+ assert(DataTypeUtils.sameType(df.schema, emptyDf.schema))
checkAnswer(df, emptyDf)
}
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala
index 36905027cb0..33c6e4f5be6 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.connector
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.DynamicPruningExpression
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.execution.InSubqueryExec
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.internal.SQLConf
@@ -141,14 +142,15 @@ class GroupBasedDeleteFromTableSuite extends
DeleteFromTableSuiteBase {
val primaryScan = collect(executedPlan) {
case s: BatchScanExec => s
}.head
- assert(primaryScan.schema.sameType(StructType.fromDDL(primaryScanSchema)))
+ assert(DataTypeUtils.sameType(primaryScan.schema, StructType.fromDDL(primaryScanSchema)))
primaryScan.runtimeFilters match {
case Seq(DynamicPruningExpression(child: InSubqueryExec)) =>
val groupFilterScan = collect(child.plan) {
case s: BatchScanExec => s
}.head
- assert(groupFilterScan.schema.sameType(StructType.fromDDL(groupFilterScanSchema)))
+ assert(DataTypeUtils.sameType(groupFilterScan.schema,
+ StructType.fromDDL(groupFilterScanSchema)))
case _ =>
fail("could not find group filter scan")
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
index bd9c79e5b96..bf496d6db21 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.SchemaPruningTest
import org.apache.spark.sql.catalyst.expressions.Concat
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.logical.Expand
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.functions._
@@ -851,7 +852,7 @@ abstract class SchemaPruningSuite
protected val schemaEquality = new Equality[StructType] {
override def areEqual(a: StructType, b: Any): Boolean =
b match {
- case otherType: StructType => a.sameType(otherType)
+ case otherType: StructType => DataTypeUtils.sameType(a, otherType)
case _ => false
}
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
index 43626237b13..183c4f71df6 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.SparkConf
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils,
TimestampFormatter}
import org.apache.spark.sql.catalyst.util.DateTimeUtils.localDateTimeToMicros
import org.apache.spark.sql.execution.datasources._
@@ -1101,7 +1102,7 @@ abstract class ParquetPartitionDiscoverySuite
val input = spark.read.parquet(path.getAbsolutePath).select("id",
"date_month", "date_hour", "date_t_hour", "data")
- assert(data.schema.sameType(input.schema))
+ assert(DataTypeUtils.sameType(data.schema, input.schema))
checkAnswer(input, data)
}
}
diff --git
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
index 9af34ba3b20..90e8f9b9d0e 100644
---
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
+++
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
@@ -41,6 +41,7 @@ import
org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils._
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap,
CharVarcharUtils}
import org.apache.spark.sql.execution.command.DDLUtils
import org.apache.spark.sql.execution.datasources.{PartitioningUtils,
SourceOptions}
@@ -815,7 +816,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf,
hadoopConf: Configurat
val partColumnNames = getPartitionColumnsFromTableProperties(table)
val reorderedSchema = reorderSchema(schema = schemaFromTableProps,
partColumnNames)
- if (DataType.equalsIgnoreCaseAndNullability(reorderedSchema, table.schema) ||
+ if (DataTypeUtils.equalsIgnoreCaseAndNullability(reorderedSchema, table.schema) ||
options.respectSparkSchema) {
hiveTable.copy(
schema = reorderedSchema,
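The same relocation applies to equalsIgnoreCaseAndNullability: the former DataType.equalsIgnoreCaseAndNullability(a, b) entry point is replaced by DataTypeUtils.equalsIgnoreCaseAndNullability(a, b). A minimal sketch with hypothetical schemas (only the helper's name and two-argument shape are taken from the hunk above):

    import org.apache.spark.sql.catalyst.types.DataTypeUtils
    import org.apache.spark.sql.types.StructType

    // Hypothetical schemas standing in for the table-property and Hive schemas above.
    val schemaFromTableProps = StructType.fromDDL("id INT, Name STRING")
    val hiveSchema           = StructType.fromDDL("ID INT, name STRING")

    // Before this patch: DataType.equalsIgnoreCaseAndNullability(schemaFromTableProps, hiveSchema)
    // After this patch:
    if (DataTypeUtils.equalsIgnoreCaseAndNullability(schemaFromTableProps, hiveSchema)) {
      // Schemas match up to field-name case and nullability, so the stored Spark schema can be kept.
    }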
diff --git
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 12b570e8186..4d13c02c503 100644
---
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier}
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat,
ParquetOptions}
import org.apache.spark.sql.internal.SQLConf
@@ -89,7 +90,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession:
SparkSession) extends Log
// we will use the cached relation.
val useCached =
relation.location.rootPaths.toSet == pathsInMetastore.toSet &&
- logical.schema.sameType(schemaInMetastore) &&
+ DataTypeUtils.sameType(logical.schema, schemaInMetastore) &&
// We don't support hive bucketed tables. This function
`getCached` is only used for
// converting supported Hive tables to data source tables.
relation.bucketSpec.isEmpty &&
diff --git
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
index e76ef4725f7..d1b94815c5d 100644
---
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
+++
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
@@ -27,6 +27,7 @@ import org.apache.logging.log4j.Level
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat,
CatalogTable, CatalogTableType}
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import
org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME
import org.apache.spark.sql.execution.command.CreateTableCommand
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation,
LogicalRelation}
@@ -776,7 +777,7 @@ class MetastoreDataSourcesSuite extends QueryTest with
SQLTestUtils with TestHiv
})
// Make sure partition columns are correctly stored in metastore.
assert(
- expectedPartitionColumns.sameType(actualPartitionColumns),
+ DataTypeUtils.sameType(expectedPartitionColumns, actualPartitionColumns),
s"Partitions columns stored in metastore $actualPartitionColumns is
not the " +
s"partition columns defined by the saveAsTable operation
$expectedPartitionColumns.")
@@ -818,7 +819,7 @@ class MetastoreDataSourcesSuite extends QueryTest with
SQLTestUtils with TestHiv
})
// Make sure bucketBy columns are correctly stored in metastore.
assert(
- expectedBucketByColumns.sameType(actualBucketByColumns),
+ DataTypeUtils.sameType(expectedBucketByColumns, actualBucketByColumns),
s"Partitions columns stored in metastore $actualBucketByColumns is not
the " +
s"partition columns defined by the saveAsTable operation
$expectedBucketByColumns.")
@@ -829,7 +830,7 @@ class MetastoreDataSourcesSuite extends QueryTest with
SQLTestUtils with TestHiv
})
// Make sure sortBy columns are correctly stored in metastore.
assert(
- expectedSortByColumns.sameType(actualSortByColumns),
+ DataTypeUtils.sameType(expectedSortByColumns, actualSortByColumns),
s"Partitions columns stored in metastore $actualSortByColumns is not
the " +
s"partition columns defined by the saveAsTable operation
$expectedSortByColumns.")
diff --git
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala
index 8b7f7ade560..c81c4649107 100644
---
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala
+++
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.hive.execution
import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat,
CatalogTable, CatalogTableType, HiveTableRelation}
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
import
org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME
import org.apache.spark.sql.execution.SQLViewSuite
import org.apache.spark.sql.hive.{HiveExternalCatalog, HiveUtils}
@@ -143,7 +144,7 @@ class HiveSQLViewSuite extends SQLViewSuite with
TestHiveSingleton {
// Check the output rows.
checkAnswer(df, Row(1, 2))
// Check the output schema.
- assert(df.schema.sameType(view.schema))
+ assert(DataTypeUtils.sameType(df.schema, view.schema))
}
}
}