This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
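In short, the commit below removes the `private[spark]` comparison helpers from `DataType` and exposes the same logic as static methods on a new `DataTypeUtils` object, so call sites change from `left.sameType(right)` to `DataTypeUtils.sameType(left, right)`. A minimal sketch of that call pattern (the object name `SameTypeSketch` is illustrative only; it assumes just the `DataTypeUtils` object added in the diff below):

    import org.apache.spark.sql.catalyst.types.DataTypeUtils
    import org.apache.spark.sql.types.{ArrayType, StringType}

    object SameTypeSketch {
      def main(args: Array[String]): Unit = {
        // Two array types that differ only in the nullability of their elements.
        val left = ArrayType(StringType, containsNull = true)
        val right = ArrayType(StringType, containsNull = false)

        // Old call shape (removed by this commit): left.sameType(right)
        // New call shape (added by this commit):
        val same = DataTypeUtils.sameType(left, right)

        println(same) // true: sameType ignores nullability flags
      }
    }

As a side effect of the move, `DataType.scala` itself no longer needs the `SQLConf` import that drove the case-sensitivity choice, which is visible in the diff.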
The following commit(s) were added to refs/heads/master by this push:
     new 631e8eb957d [SPARK-42892][SQL] Move sameType and relevant methods out of DataType
631e8eb957d is described below

commit 631e8eb957de1d86c184182d7dd363c01f15af25
Author: Rui Wang <rui.w...@databricks.com>
AuthorDate: Wed Mar 22 17:03:49 2023 -0400

[SPARK-42892][SQL] Move sameType and relevant methods out of DataType

### What changes were proposed in this pull request?

This PR moves the following methods out of `DataType`:

1. equalsIgnoreNullability
2. sameType
3. equalsIgnoreCaseAndNullability

The moved methods are put together in a util class, `DataTypeUtils`.

### Why are the changes needed?

To make `DataType` a simpler interface, non-public methods can be moved out of the `DataType` class.

### Does this PR introduce _any_ user-facing change?

No, as the moved methods are private within Spark.

### How was this patch tested?

Existing UT.

Closes #40512 from amaliujia/catalyst_refactor_1.

Authored-by: Rui Wang <rui.w...@databricks.com>
Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../apache/spark/sql/kafka010/KafkaWriter.scala    |  3 +-
 .../scala/org/apache/spark/ml/feature/NGram.scala  |  3 +-
 .../apache/spark/ml/feature/StopWordsRemover.scala |  3 +-
 .../spark/ml/source/libsvm/LibSVMRelation.scala    |  5 +-
 .../apache/spark/mllib/util/modelSaveLoad.scala    |  3 +-
 .../org/apache/spark/ml/feature/LSHTest.scala      | 11 ++--
 .../catalyst/analysis/ResolveInlineTables.scala    |  3 +-
 .../catalyst/analysis/TableOutputResolver.scala    |  3 +-
 .../spark/sql/catalyst/analysis/TypeCoercion.scala |  8 ++-
 .../spark/sql/catalyst/catalog/interface.scala     |  5 +-
 .../apache/spark/sql/catalyst/dsl/package.scala    |  3 +-
 .../sql/catalyst/expressions/Expression.scala      |  3 +-
 .../sql/catalyst/expressions/SchemaPruning.scala   |  3 +-
 .../expressions/collectionOperations.scala         | 23 +++----
 .../spark/sql/catalyst/expressions/hash.scala      |  3 +-
 .../expressions/higherOrderFunctions.scala         |  3 +-
 .../spark/sql/catalyst/optimizer/Optimizer.scala   |  4 +-
 .../optimizer/UnwrapCastInBinaryComparison.scala   |  3 +-
 .../spark/sql/catalyst/optimizer/objects.scala     |  6 +-
 .../sql/catalyst/plans/logical/LogicalPlan.scala   |  5 +-
 .../plans/logical/basicLogicalOperators.scala      |  9 ++-
 .../spark/sql/catalyst/types/DataTypeUtils.scala   | 77 ++++++++++++++++++++++
 .../sql/internal/connector/PredicateUtils.scala    |  6 +-
 .../org/apache/spark/sql/types/DataType.scala      | 60 +----------------
 .../org/apache/spark/sql/types/StructType.scala    |  4 +-
 .../catalyst/analysis/AnsiTypeCoercionSuite.scala  |  3 +-
 .../sql/catalyst/analysis/TypeCoercionSuite.scala  |  3 +-
 .../catalyst/expressions/SelectedFieldSuite.scala  |  3 +-
 .../optimizer/ObjectSerializerPruningSuite.scala   |  6 +-
 .../org/apache/spark/sql/types/DataTypeSuite.scala |  3 +-
 .../apache/spark/sql/types/StructTypeSuite.scala   | 23 +++----
 .../datasources/parquet/ParquetColumnVector.java   |  3 +-
 .../datasources/binaryfile/BinaryFileFormat.scala  |  3 +-
 .../spark/sql/execution/joins/HashJoin.scala       |  3 +-
 .../ApplyInPandasWithStatePythonRunner.scala       |  3 +-
 .../streaming/StreamingSymmetricHashJoinExec.scala |  3 +-
 .../spark/sql/FileBasedDataSourceSuite.scala       |  5 +-
 .../connector/GroupBasedDeleteFromTableSuite.scala |  6 +-
 .../execution/datasources/SchemaPruningSuite.scala |  3 +-
 .../parquet/ParquetPartitionDiscoverySuite.scala   |  3 +-
 .../spark/sql/hive/HiveExternalCatalog.scala       |  3 +-
 .../spark/sql/hive/HiveMetastoreCatalog.scala      |  3 +-
 .../spark/sql/hive/MetastoreDataSourcesSuite.scala |  7 +-
 .../sql/hive/execution/HiveSQLViewSuite.scala      |  3 +-
 44
files changed, 210 insertions(+), 138 deletions(-) diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala index 5ef4b3a1c19..92c51416f48 100644 --- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala @@ -23,6 +23,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType, StringType} import org.apache.spark.util.Utils @@ -115,7 +116,7 @@ private[kafka010] object KafkaWriter extends Logging { desired: Seq[DataType])( default: => Expression): Expression = { val expr = schema.find(_.name == attrName).getOrElse(default) - if (!desired.exists(_.sameType(expr.dataType))) { + if (!desired.exists(e => DataTypeUtils.sameType(e, expr.dataType))) { throw new IllegalStateException(s"$attrName attribute unsupported type " + s"${expr.dataType.catalogString}. $attrName must be a(n) " + s"${desired.map(_.catalogString).mkString(" or ")}") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala index fd6fde0744d..d72fb6ecc76 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala @@ -21,6 +21,7 @@ import org.apache.spark.annotation.Since import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param._ import org.apache.spark.ml.util._ +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.types.{ArrayType, DataType, StringType} /** @@ -64,7 +65,7 @@ class NGram @Since("1.5.0") (@Since("1.5.0") override val uid: String) } override protected def validateInputType(inputType: DataType): Unit = { - require(inputType.sameType(ArrayType(StringType)), + require(DataTypeUtils.sameType(inputType, ArrayType(StringType)), s"Input type must be ${ArrayType(StringType).catalogString} but got " + inputType.catalogString) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 8bcd7909b60..056baa1b6bf 100755 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -25,6 +25,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasInputCols, HasOutputCol, HasOutputCols} import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType} @@ -191,7 +192,7 @@ class StopWordsRemover @Since("1.5.0") (@Since("1.5.0") override val uid: String require(!schema.fieldNames.contains(outputColName), s"Output Column $outputColName already exists.") val inputType = schema(inputColName).dataType - require(inputType.sameType(ArrayType(StringType)), "Input type must be " + + 
require(DataTypeUtils.sameType(inputType, ArrayType(StringType)), "Input type must be " + s"${ArrayType(StringType).catalogString} but got ${inputType.catalogString}.") StructField(outputColName, inputType, schema(inputColName).nullable) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index 6cd635f9cd9..f4c5e3eece2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ @@ -80,8 +81,8 @@ private[libsvm] class LibSVMFileFormat private def verifySchema(dataSchema: StructType, forWriting: Boolean): Unit = { if ( dataSchema.size != 2 || - !dataSchema(0).dataType.sameType(DataTypes.DoubleType) || - !dataSchema(1).dataType.sameType(new VectorUDT()) || + !DataTypeUtils.sameType(dataSchema(0).dataType, DataTypes.DoubleType) || + !DataTypeUtils.sameType(dataSchema(1).dataType, new VectorUDT()) || !(forWriting || dataSchema(1).metadata.getLong(LibSVMOptions.NUM_FEATURES).toInt > 0) ) { throw new IOException(s"Illegal schema for libsvm data, schema=$dataSchema") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala index 8f2d8b9014c..c13bc4099ce 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala @@ -26,6 +26,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext import org.apache.spark.annotation.Since import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.types.{DataType, StructField, StructType} /** @@ -105,7 +106,7 @@ private[mllib] object Loader { assert(loadedFields.contains(field.name), s"Unable to parse model data." + s" Expected field with name ${field.name} was missing in loaded schema:" + s" ${loadedFields.mkString(", ")}") - assert(loadedFields(field.name).sameType(field.dataType), + assert(DataTypeUtils.sameType(loadedFields(field.name), field.dataType), s"Unable to parse model data. 
Expected field $field but found field" + s" with different type: ${loadedFields(field.name)}") } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala index 2815adb75ad..e5e891d5b8f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala @@ -22,6 +22,7 @@ import org.scalatest.Assertions._ import org.apache.spark.ml.linalg.{Vector, VectorUDT} import org.apache.spark.ml.util.{MLTestingUtils, SchemaUtils} import org.apache.spark.sql.Dataset +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.DataTypes @@ -116,7 +117,7 @@ private[ml] object LSHTest { // Compute actual val actual = model.approxNearestNeighbors(dataset, key, k, singleProbe, "distCol") - assert(actual.schema.sameType(model + assert(DataTypeUtils.sameType(actual.schema, model .transformSchema(dataset.schema) .add("distCol", DataTypes.DoubleType)) ) @@ -156,10 +157,10 @@ private[ml] object LSHTest { val actual = model.approxSimilarityJoin(datasetA, datasetB, threshold) SchemaUtils.checkColumnType(actual.schema, "distCol", DataTypes.DoubleType) - assert(actual.schema.apply("datasetA").dataType - .sameType(model.transformSchema(datasetA.schema))) - assert(actual.schema.apply("datasetB").dataType - .sameType(model.transformSchema(datasetB.schema))) + assert(DataTypeUtils.sameType(actual.schema.apply("datasetA").dataType, + model.transformSchema(datasetA.schema))) + assert(DataTypeUtils.sameType(actual.schema.apply("datasetB").dataType, + model.transformSchema(datasetB.schema))) // Compute precision and recall val correctCount = actual.filter(col("distCol") < threshold).count().toDouble diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala index 48c0b83d240..3952ef71b64 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.AliasHelper import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.AlwaysProcess +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.types.{StructField, StructType} /** @@ -106,7 +107,7 @@ object ResolveInlineTables extends Rule[LogicalPlan] with CastSupport with Alias InternalRow.fromSeq(row.zipWithIndex.map { case (e, ci) => val targetType = fields(ci).dataType try { - val castedExpr = if (e.dataType.sameType(targetType)) { + val castedExpr = if (DataTypeUtils.sameType(e.dataType, targetType)) { e } else { cast(e, targetType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala index e1ee0defa23..71f4eb2918c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala @@ -21,6 +21,7 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions._ 
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.errors.QueryCompilationErrors @@ -266,7 +267,7 @@ object TableOutputResolver { tableAttr.dataType } val storeAssignmentPolicy = conf.storeAssignmentPolicy - lazy val outputField = if (tableAttr.dataType.sameType(queryExpr.dataType) && + lazy val outputField = if (DataTypeUtils.sameType(tableAttr.dataType, queryExpr.dataType) && tableAttr.name == queryExpr.name && tableAttr.metadata == queryExpr.metadata) { Some(queryExpr) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index e57d1075d2f..059c36c4f90 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -154,12 +155,12 @@ abstract class TypeCoercionBase { true } else { val head = types.head - types.tail.forall(_.sameType(head)) + types.tail.forall(e => DataTypeUtils.sameType(e, head)) } } protected def castIfNotSameType(expr: Expression, dt: DataType): Expression = { - if (!expr.dataType.sameType(dt)) { + if (!DataTypeUtils.sameType(expr.dataType, dt)) { Cast(expr, dt) } else { expr @@ -622,7 +623,8 @@ abstract class TypeCoercionBase { override val transform: PartialFunction[Expression, Expression] = { // Lambda function isn't resolved when the rule is executed. 
case m @ MapZipWith(left, right, function) if m.arguments.forall(a => a.resolved && - MapType.acceptsType(a.dataType)) && !m.leftKeyType.sameType(m.rightKeyType) => + MapType.acceptsType(a.dataType)) && + !DataTypeUtils.sameType(m.leftKeyType, m.rightKeyType) => findWiderTypeForTwo(m.leftKeyType, m.rightKeyType) match { case Some(finalKeyType) if !Cast.forceNullable(m.leftKeyType, finalKeyType) && !Cast.forceNullable(m.rightKeyType, finalKeyType) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 08dd2dfd5bc..8bbf1f4f564 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTable.VIEW_STORING_ANALYZED_ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Cast, ExprId, Literal} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.errors.QueryCompilationErrors @@ -824,8 +825,8 @@ case class HiveTableRelation( @transient prunedPartitions: Option[Seq[CatalogTablePartition]] = None) extends LeafNode with MultiInstanceRelation { assert(tableMeta.identifier.database.isDefined) - assert(tableMeta.partitionSchema.sameType(partitionCols.toStructType)) - assert(tableMeta.dataSchema.sameType(dataCols.toStructType)) + assert(DataTypeUtils.sameType(tableMeta.partitionSchema, partitionCols.toStructType)) + assert(DataTypeUtils.sameType(tableMeta.dataSchema, dataCols.toStructType)) // The partition column should always appear after data columns. 
override def output: Seq[AttributeReference] = dataCols ++ partitionCols diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index c6cc863108d..ac439203cb7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -129,7 +130,7 @@ package object dsl { UnresolvedExtractValue(expr, Literal(fieldName)) def cast(to: DataType): Expression = { - if (expr.resolved && expr.dataType.sameType(to)) { + if (expr.resolved && DataTypeUtils.sameType(expr.dataType, to)) { expr } else { val cast = Cast(expr, to) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 2d2236a8a80..92bffc2be2e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.{BinaryLike, CurrentOrigin, LeafLike, QuaternaryLike, SQLQueryContext, TernaryLike, TreeNode, UnaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern.{RUNTIME_REPLACEABLE, TreePattern} +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.errors.{QueryErrorsBase, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf @@ -771,7 +772,7 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes wi override def checkInputDataTypes(): TypeCheckResult = { // First check whether left and right have the same type, then check if the type is acceptable. 
- if (!left.dataType.sameType(right.dataType)) { + if (!DataTypeUtils.sameType(left.dataType, right.dataType)) { DataTypeMismatch( errorSubClass = "BINARY_OP_DIFF_TYPES", messageParameters = Map( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala index e14bcba0ace..dd2d6c2cb61 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SchemaPruning.scala @@ -22,6 +22,7 @@ import java.util.Locale import scala.collection.immutable.HashMap import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.types._ object SchemaPruning extends SQLConfHelper { @@ -126,7 +127,7 @@ object SchemaPruning extends SQLConfHelper { val rootFieldType = StructType(Array(root.field)) val optFieldType = StructType(Array(opt.field)) val merged = optFieldType.merge(rootFieldType) - merged.sameType(optFieldType) + DataTypeUtils.sameType(merged, optFieldType) } } } ++ rootFields diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 2ccb3a6d0cd..adeccb3ec7e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.{BinaryLike, SQLQueryContext, UnaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern.{ARRAYS_ZIP, CONCAT, TreePattern} +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.catalyst.util.DateTimeUtils._ @@ -67,7 +68,7 @@ trait BinaryArrayExpressionWithImplicitCast override def checkInputDataTypes(): TypeCheckResult = { (left.dataType, right.dataType) match { - case (ArrayType(e1, _), ArrayType(e2, _)) if e1.sameType(e2) => + case (ArrayType(e1, _), ArrayType(e2, _)) if DataTypeUtils.sameType(e1, e2) => TypeCheckResult.TypeCheckSuccess case _ => DataTypeMismatch( @@ -245,7 +246,7 @@ case class MapContainsKey(left: Expression, right: Expression) DataTypeMismatch( errorSubClass = "NULL_TYPE", Map("functionName" -> toSQLId(prettyName))) - case (MapType(kt, _, _), dt) if kt.sameType(dt) => + case (MapType(kt, _, _), dt) if DataTypeUtils.sameType(kt, dt) => TypeUtils.checkForOrderingExpr(kt, prettyName) case _ => DataTypeMismatch( @@ -1327,7 +1328,7 @@ case class ArrayContains(left: Expression, right: Expression) "inputSql" -> toSQLExpr(left), "inputType" -> toSQLType(left.dataType)) ) - case (ArrayType(e1, _), e2) if e1.sameType(e2) => + case (ArrayType(e1, _), e2) if DataTypeUtils.sameType(e1, e2) => TypeUtils.checkForOrderingExpr(e2, prettyName) case _ => DataTypeMismatch( @@ -1512,7 +1513,7 @@ case class ArrayPrepend(left: Expression, right: Expression) override def checkInputDataTypes(): TypeCheckResult = { (left.dataType, right.dataType) match { - case (ArrayType(e1, _), e2) if e1.sameType(e2) => TypeCheckResult.TypeCheckSuccess + case 
(ArrayType(e1, _), e2) if DataTypeUtils.sameType(e1, e2) => TypeCheckResult.TypeCheckSuccess case (ArrayType(e1, _), e2) => DataTypeMismatch( errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES", messageParameters = Map( @@ -2266,7 +2267,7 @@ case class ArrayPosition(left: Expression, right: Expression) "inputSql" -> toSQLExpr(left), "inputType" -> toSQLType(left.dataType)) ) - case (ArrayType(e1, _), e2) if e1.sameType(e2) => + case (ArrayType(e1, _), e2) if DataTypeUtils.sameType(e1, e2) => TypeUtils.checkForOrderingExpr(e2, prettyName) case _ => DataTypeMismatch( @@ -2426,7 +2427,7 @@ case class ElementAt( "inputSql" -> toSQLExpr(right), "inputType" -> toSQLType(right.dataType)) ) - case (MapType(e1, _, _), e2) if (!e2.sameType(e1)) => + case (MapType(e1, _, _), e2) if (!DataTypeUtils.sameType(e2, e1)) => DataTypeMismatch( errorSubClass = "MAP_FUNCTION_DIFF_TYPES", messageParameters = Map( @@ -3026,7 +3027,7 @@ case class Sequence( val startType = start.dataType def stepType = stepOpt.get.dataType val typesCorrect = - startType.sameType(stop.dataType) && + DataTypeUtils.sameType(startType, stop.dataType) && (startType match { case TimestampType | TimestampNTZType => stepOpt.isEmpty || CalendarIntervalType.acceptsType(stepType) || @@ -3037,7 +3038,7 @@ case class Sequence( YearMonthIntervalType.acceptsType(stepType) || DayTimeIntervalType.acceptsType(stepType) case _: IntegralType => - stepOpt.isEmpty || stepType.sameType(startType) + stepOpt.isEmpty || DataTypeUtils.sameType(stepType, startType) case _ => false }) @@ -3715,7 +3716,7 @@ case class ArrayRemove(left: Expression, right: Expression) override def checkInputDataTypes(): TypeCheckResult = { (left.dataType, right.dataType) match { - case (ArrayType(e1, _), e2) if e1.sameType(e2) => + case (ArrayType(e1, _), e2) if DataTypeUtils.sameType(e1, e2) => TypeUtils.checkForOrderingExpr(e2, prettyName) case _ => DataTypeMismatch( @@ -4791,7 +4792,7 @@ case class ArrayInsert(srcArrayExpr: Expression, posExpr: Expression, itemExpr: "inputSql" -> toSQLExpr(second), "inputType" -> toSQLType(second.dataType)) ) - case (ArrayType(e1, _), e2, e3) if e1.sameType(e3) => + case (ArrayType(e1, _), e2, e3) if DataTypeUtils.sameType(e1, e3) => TypeCheckResult.TypeCheckSuccess case _ => DataTypeMismatch( @@ -5078,7 +5079,7 @@ case class ArrayAppend(left: Expression, right: Expression) override def checkInputDataTypes(): TypeCheckResult = { (left.dataType, right.dataType) match { - case (ArrayType(e1, _), e2) if e1.sameType(e2) => TypeCheckResult.TypeCheckSuccess + case (ArrayType(e1, _), e2) if DataTypeUtils.sameType(e1, e2) => TypeCheckResult.TypeCheckSuccess case (ArrayType(e1, _), e2) => DataTypeMismatch( errorSubClass = "ARRAY_FUNCTION_DIFF_TYPES", messageParameters = Map( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index 8ac879c73ae..3e1248482cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import 
org.apache.spark.sql.catalyst.util.DateTimeConstants._ import org.apache.spark.sql.errors.QueryCompilationErrors @@ -308,7 +309,7 @@ abstract class HashExpression[E] extends Expression { } val hashResultType = CodeGenerator.javaType(dataType) - val typedSeed = if (dataType.sameType(LongType)) s"${seed}L" else s"$seed" + val typedSeed = if (DataTypeUtils.sameType(dataType, LongType)) s"${seed}L" else s"$seed" val codes = ctx.splitExpressionsWithCurrentInputs( expressions = childrenHash, funcName = "computeHash", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index c2db38bae45..fec1df108bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.Cast._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.trees.{BinaryLike, QuaternaryLike, TernaryLike} import org.apache.spark.sql.catalyst.trees.TreePattern._ +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf @@ -1059,7 +1060,7 @@ case class MapZipWith(left: Expression, right: Expression, function: Expression) override def checkArgumentDataTypes(): TypeCheckResult = { super.checkArgumentDataTypes() match { case TypeCheckResult.TypeCheckSuccess => - if (leftKeyType.sameType(rightKeyType)) { + if (DataTypeUtils.sameType(leftKeyType, rightKeyType)) { TypeUtils.checkForOrderingExpr(leftKeyType, prettyName) } else { DataTypeMismatch( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 3e8571f3eb6..ffc7284e40a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{RepartitionOperation, _} import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.AlwaysProcess import org.apache.spark.sql.catalyst.trees.TreePattern._ +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf @@ -1621,7 +1622,8 @@ object EliminateSorts extends Rule[LogicalPlan] { // Arithmetic operations for floating-point values are order-sensitive // (they are not associative). 
case _: Sum | _: Average | _: CentralMomentAgg => - !Seq(FloatType, DoubleType).exists(_.sameType(func.children.head.dataType)) + !Seq(FloatType, DoubleType) + .exists(e => DataTypeUtils.sameType(e, func.children.head.dataType)) case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala index 4c41fb0fa31..50c5bdc7b8d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/UnwrapCastInBinaryComparison.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.FalseLiteral import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern.{BINARY_COMPARISON, IN, INSET} +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.types._ /** @@ -380,7 +381,7 @@ object UnwrapCastInBinaryComparison extends Rule[LogicalPlan] { fromExp: Expression, toType: DataType, literalType: DataType): Boolean = { - toType.sameType(literalType) && + DataTypeUtils.sameType(toType, literalType) && !fromExp.foldable && toType.isInstanceOf[NumericType] && canUnwrapCast(fromExp.dataType, toType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala index 3387bb20077..80ec5e40bec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreePattern._ +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType, UserDefinedType} /* @@ -165,7 +166,8 @@ object ObjectSerializerPruning extends Rule[LogicalPlan] { * Note: we should do `transformUp` explicitly to change data types. 
*/ private def alignNullTypeInIf(expr: Expression) = expr.transformUp { - case i @ If(IsNullCondition(), Literal(null, dt), ser) if !dt.sameType(ser.dataType) => + case i @ If(IsNullCondition(), Literal(null, dt), ser) + if !DataTypeUtils.sameType(dt, ser.dataType) => i.copy(trueValue = Literal(null, ser.dataType)) } @@ -204,7 +206,7 @@ object ObjectSerializerPruning extends Rule[LogicalPlan] { val transformedSerializer = serializer.transformDown(transformer) val prunedSerializer = alignNullTypeInIf(transformedSerializer).asInstanceOf[NamedExpression] - if (prunedSerializer.dataType.sameType(prunedDataType)) { + if (DataTypeUtils.sameType(prunedSerializer.dataType, prunedDataType)) { prunedSerializer } else { serializer diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 36187bb2d55..e1d46ff73f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -23,9 +23,10 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{AliasAwareQueryOutputOrdering, QueryPlan} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.LogicalPlanStats import org.apache.spark.sql.catalyst.trees.{BinaryLike, LeafLike, TreeNodeTag, UnaryLike} +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.MetadataColumnHelper import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.types.StructType abstract class LogicalPlan @@ -314,7 +315,7 @@ object LogicalPlanIntegrity { Some("Special expressions are placed in the wrong plan: " + currentPlan.treeString) } else { LogicalPlanIntegrity.validateExprIdUniqueness(currentPlan).orElse { - if (!DataType.equalsIgnoreNullability(previousPlan.schema, currentPlan.schema)) { + if (!DataTypeUtils.equalsIgnoreNullability(previousPlan.schema, currentPlan.schema)) { Some(s"The plan output schema has changed from ${previousPlan.schema.sql} to " + currentPlan.schema.sql + s". 
The previous plan: ${previousPlan.treeString}\nThe new " + "plan:\n" + currentPlan.treeString) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index cdf48dd265f..1eddd7ed24d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition} import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.catalyst.trees.TreePattern._ +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf @@ -330,7 +331,7 @@ abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends Binar childrenResolved && left.output.length == right.output.length && left.output.zip(right.output).forall { case (l, r) => - l.dataType.sameType(r.dataType) + DataTypeUtils.sameType(l.dataType, r.dataType) } && duplicateResolved } @@ -1280,7 +1281,9 @@ object Expand { } :+ { val bitMask = buildBitmask(groupingSetAttrs, attrMap) val dataType = GroupingID.dataType - Literal.create(if (dataType.sameType(IntegerType)) bitMask.toInt else bitMask, dataType) + Literal.create( + if (DataTypeUtils.sameType(dataType, IntegerType)) bitMask.toInt + else bitMask, dataType) } if (hasDuplicateGroupingSets) { @@ -1502,7 +1505,7 @@ case class Unpivot( def valuesTypeCoercioned: Boolean = canBeCoercioned && // all inner values at position idx must have the same data type values.get.head.zipWithIndex.forall { case (v, idx) => - values.get.tail.forall(vals => vals(idx).dataType.sameType(v.dataType)) + values.get.tail.forall(vals => DataTypeUtils.sameType(vals(idx).dataType, v.dataType)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala new file mode 100644 index 00000000000..c27a17b6dd9 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst.types + +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} + +object DataTypeUtils { + /** + * Check if `this` and `other` are the same data type when ignoring nullability + * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). + */ + def sameType(left: DataType, right: DataType): Boolean = + if (SQLConf.get.caseSensitiveAnalysis) { + equalsIgnoreNullability(left, right) + } else { + equalsIgnoreCaseAndNullability(left, right) + } + + /** + * Compares two types, ignoring nullability of ArrayType, MapType, StructType. + */ + def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = { + (left, right) match { + case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) => + equalsIgnoreNullability(leftElementType, rightElementType) + case (MapType(leftKeyType, leftValueType, _), MapType(rightKeyType, rightValueType, _)) => + equalsIgnoreNullability(leftKeyType, rightKeyType) && + equalsIgnoreNullability(leftValueType, rightValueType) + case (StructType(leftFields), StructType(rightFields)) => + leftFields.length == rightFields.length && + leftFields.zip(rightFields).forall { case (l, r) => + l.name == r.name && equalsIgnoreNullability(l.dataType, r.dataType) + } + case (l, r) => l == r + } + } + + /** + * Compares two types, ignoring nullability of ArrayType, MapType, StructType, and ignoring case + * sensitivity of field names in StructType. + */ + def equalsIgnoreCaseAndNullability(from: DataType, to: DataType): Boolean = { + (from, to) match { + case (ArrayType(fromElement, _), ArrayType(toElement, _)) => + equalsIgnoreCaseAndNullability(fromElement, toElement) + + case (MapType(fromKey, fromValue, _), MapType(toKey, toValue, _)) => + equalsIgnoreCaseAndNullability(fromKey, toKey) && + equalsIgnoreCaseAndNullability(fromValue, toValue) + + case (StructType(fromFields), StructType(toFields)) => + fromFields.length == toFields.length && + fromFields.zip(toFields).forall { case (l, r) => + l.name.equalsIgnoreCase(r.name) && + equalsIgnoreCaseAndNullability(l.dataType, r.dataType) + } + + case (fromDataType, toDataType) => fromDataType == toDataType + } + } +} + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala index a08223e2159..339febb374e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/PredicateUtils.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.internal.connector import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.connector.expressions.{LiteralValue, NamedReference} import org.apache.spark.sql.connector.expressions.filter.{And => V2And, Not => V2Not, Or => V2Or, Predicate} import org.apache.spark.sql.sources.{AlwaysFalse, AlwaysTrue, And, EqualNullSafe, EqualTo, Filter, GreaterThan, GreaterThanOrEqual, In, IsNotNull, IsNull, LessThan, LessThanOrEqual, Not, Or, StringContains, StringEndsWith, StringStartsWith} @@ -44,7 +45,8 @@ private[sql] object PredicateUtils { if (values.length > 0) { if (!values.forall(_.isInstanceOf[LiteralValue[_]])) return None val dataType = values(0).asInstanceOf[LiteralValue[_]].dataType - if 
(!values.forall(_.asInstanceOf[LiteralValue[_]].dataType.sameType(dataType))) { + if (!values.forall(e => + DataTypeUtils.sameType(e.asInstanceOf[LiteralValue[_]].dataType, dataType))) { return None } val inValues = values.map(v => @@ -80,7 +82,7 @@ private[sql] object PredicateUtils { case "STARTS_WITH" | "ENDS_WITH" | "CONTAINS" if isValidBinaryPredicate() => val attribute = predicate.children()(0).toString val value = predicate.children()(1).asInstanceOf[LiteralValue[_]] - if (!value.dataType.sameType(StringType)) return None + if (!DataTypeUtils.sameType(value.dataType, StringType)) return None val v1Value = value.value.toString val v1Filter = predicate.name() match { case "STARTS_WITH" => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 13a7b03bc61..5e2974c2645 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -36,7 +36,6 @@ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.util.DataTypeJsonUtils.{DataTypeJsonDeserializer, DataTypeJsonSerializer} import org.apache.spark.sql.catalyst.util.StringUtils.StringConcat import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy.{ANSI, STRICT} import org.apache.spark.sql.types.DayTimeIntervalType._ @@ -93,17 +92,6 @@ abstract class DataType extends AbstractDataType { def sql: String = simpleString.toUpperCase(Locale.ROOT) - /** - * Check if `this` and `other` are the same data type when ignoring nullability - * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). - */ - private[spark] def sameType(other: DataType): Boolean = - if (SQLConf.get.caseSensitiveAnalysis) { - DataType.equalsIgnoreNullability(this, other) - } else { - DataType.equalsIgnoreCaseAndNullability(this, other) - } - /** * Returns the same data type but set all nullability fields are true * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). @@ -117,7 +105,8 @@ abstract class DataType extends AbstractDataType { override private[sql] def defaultConcreteType: DataType = this - override private[sql] def acceptsType(other: DataType): Boolean = sameType(other) + override private[sql] def acceptsType(other: DataType): Boolean = + DataTypeUtils.sameType(this, other) private[sql] def physicalDataType: PhysicalDataType = UninitializedPhysicalType } @@ -294,25 +283,6 @@ object DataType { } } - /** - * Compares two types, ignoring nullability of ArrayType, MapType, StructType. 
- */ - private[sql] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = { - (left, right) match { - case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) => - equalsIgnoreNullability(leftElementType, rightElementType) - case (MapType(leftKeyType, leftValueType, _), MapType(rightKeyType, rightValueType, _)) => - equalsIgnoreNullability(leftKeyType, rightKeyType) && - equalsIgnoreNullability(leftValueType, rightValueType) - case (StructType(leftFields), StructType(rightFields)) => - leftFields.length == rightFields.length && - leftFields.zip(rightFields).forall { case (l, r) => - l.name == r.name && equalsIgnoreNullability(l.dataType, r.dataType) - } - case (l, r) => l == r - } - } - /** * Compares two types, ignoring compatible nullability of ArrayType, MapType, StructType. * @@ -377,30 +347,6 @@ object DataType { } } - /** - * Compares two types, ignoring nullability of ArrayType, MapType, StructType, and ignoring case - * sensitivity of field names in StructType. - */ - private[sql] def equalsIgnoreCaseAndNullability(from: DataType, to: DataType): Boolean = { - (from, to) match { - case (ArrayType(fromElement, _), ArrayType(toElement, _)) => - equalsIgnoreCaseAndNullability(fromElement, toElement) - - case (MapType(fromKey, fromValue, _), MapType(toKey, toValue, _)) => - equalsIgnoreCaseAndNullability(fromKey, toKey) && - equalsIgnoreCaseAndNullability(fromValue, toValue) - - case (StructType(fromFields), StructType(toFields)) => - fromFields.length == toFields.length && - fromFields.zip(toFields).forall { case (l, r) => - l.name.equalsIgnoreCase(r.name) && - equalsIgnoreCaseAndNullability(l.dataType, r.dataType) - } - - case (fromDataType, toDataType) => fromDataType == toDataType - } - } - /** * Returns true if the two data types share the same "shape", i.e. the types * are the same, but the field names don't need to be the same. @@ -584,7 +530,7 @@ object DataType { true } - case (w, r) if w.sameType(r) && !w.isInstanceOf[NullType] => + case (w, r) if DataTypeUtils.sameType(w, r) && !w.isInstanceOf[NullType] => true case (w, r) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 9ef3c4d60fd..323f7c6df32 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering} import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, LegacyTypeStringParser} import org.apache.spark.sql.catalyst.trees.Origin -import org.apache.spark.sql.catalyst.types.{PhysicalDataType, PhysicalStructType} +import org.apache.spark.sql.catalyst.types.{DataTypeUtils, PhysicalDataType, PhysicalStructType} import org.apache.spark.sql.catalyst.util.{truncatedString, StringUtils} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._ import org.apache.spark.sql.catalyst.util.StringUtils.StringConcat @@ -704,7 +704,7 @@ object StructType extends AbstractDataType { // Found a missing field in `source`. newFields += field } else if (bothStructType(found.get.dataType, field.dataType) && - !found.get.dataType.sameType(field.dataType)) { + !DataTypeUtils.sameType(found.get.dataType, field.dataType)) { // Found a field with same name, but different data type. 
findMissingFields(found.get.dataType.asInstanceOf[StructType], field.dataType.asInstanceOf[StructType], resolver).map { missingType => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala index afbc2fdb5a0..cc0f6046de0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercionSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -65,7 +66,7 @@ class AnsiTypeCoercionSuite extends TypeCoercionSuiteBase { private def shouldCastStringLiteral(to: AbstractDataType, expected: DataType): Unit = { val input = Literal("123") val castResult = AnsiTypeCoercion.implicitCast(input, to) - assert(DataType.equalsIgnoreCaseAndNullability( + assert(DataTypeUtils.equalsIgnoreCaseAndNullability( castResult.map(_.dataType).orNull, expected), s"Failed to cast String literal to $to") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index e30cce23136..1e8286dec79 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.ReferenceAllColumns import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -53,7 +54,7 @@ abstract class TypeCoercionSuiteBase extends AnalysisTest { // Check null value val castNull = implicitCast(createNull(from), to) - assert(DataType.equalsIgnoreCaseAndNullability( + assert(DataTypeUtils.equalsIgnoreCaseAndNullability( castNull.map(_.dataType).orNull, expected), s"Failed to cast $from to $to") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SelectedFieldSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SelectedFieldSuite.scala index 3724f313ca6..e09ae776d1c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SelectedFieldSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SelectedFieldSuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.AnalysisTest import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.types._ class SelectedFieldSuite extends AnalysisTest { @@ -534,7 +535,7 @@ class SelectedFieldSuite extends AnalysisTest { indent("but it actually selected\n") + indent(StructType(actual :: 
Nil).treeString) + indent("Note that expected.dataType.sameType(actual.dataType) = " + - expected.dataType.sameType(actual.dataType))) + DataTypeUtils.sameType(expected.dataType, actual.dataType))) throw ex } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ObjectSerializerPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ObjectSerializerPruningSuite.scala index dfe190e6ddc..3dd58dc9fc1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ObjectSerializerPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ObjectSerializerPruningSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -96,7 +97,8 @@ class ObjectSerializerPruningSuite extends PlanTest { CreateNamedStruct(children.take(2)) }.transformUp { // Aligns null literal in `If` expression to make it resolvable. - case i @ If(_: IsNull, Literal(null, dt), ser) if !dt.sameType(ser.dataType) => + case i @ If(_: IsNull, Literal(null, dt), ser) + if !DataTypeUtils.sameType(dt, ser.dataType) => i.copy(trueValue = Literal(null, ser.dataType)) }.asInstanceOf[NamedExpression] @@ -126,7 +128,7 @@ class ObjectSerializerPruningSuite extends PlanTest { }.transformUp { // Aligns null literal in `If` expression to make it resolvable. case i @ If(invoke: Invoke, Literal(null, dt), ser) if invoke.functionName == "isNullAt" && - !dt.sameType(ser.dataType) => + !DataTypeUtils.sameType(dt, ser.dataType) => i.copy(trueValue = Literal(null, ser.dataType)) }.asInstanceOf[NamedExpression] diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 2a487133b48..4001b546566 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -22,6 +22,7 @@ import com.fasterxml.jackson.core.JsonParseException import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, caseSensitiveResolution} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.StringUtils.StringConcat import org.apache.spark.sql.types.DataTypeTestUtils.{dayTimeIntervalTypes, yearMonthIntervalTypes} @@ -207,7 +208,7 @@ class DataTypeSuite extends SparkFunSuite { test(s"from DDL - $dataType") { val parsed = StructType.fromDDL(s"a ${dataType.sql}") val expected = new StructType().add("a", dataType) - assert(parsed.sameType(expected)) + assert(DataTypeUtils.sameType(parsed, expected)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala index d9eb0892d13..e252594540b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/StructTypeSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.AnalysisException import 
org.apache.spark.sql.catalyst.analysis.{caseInsensitiveResolution, caseSensitiveResolution}
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.catalyst.plans.SQLHelper
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DayTimeIntervalType => DT}
@@ -166,22 +167,22 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper {
     val source1 = StructType.fromDDL("c1 INT")
     val missing1 = StructType.fromDDL("c2 STRUCT<c3: INT, c4: STRUCT<c5: INT, c6: INT>>")
     assert(StructType.findMissingFields(source1, schema, resolver)
-      .exists(_.sameType(missing1)))
+      .exists(e => DataTypeUtils.sameType(e, missing1)))

     val source2 = StructType.fromDDL("c1 INT, c3 STRING")
     val missing2 = StructType.fromDDL("c2 STRUCT<c3: INT, c4: STRUCT<c5: INT, c6: INT>>")
     assert(StructType.findMissingFields(source2, schema, resolver)
-      .exists(_.sameType(missing2)))
+      .exists(e => DataTypeUtils.sameType(e, missing2)))

     val source3 = StructType.fromDDL("c1 INT, c2 STRUCT<c3: INT>")
     val missing3 = StructType.fromDDL("c2 STRUCT<c4: STRUCT<c5: INT, c6: INT>>")
     assert(StructType.findMissingFields(source3, schema, resolver)
-      .exists(_.sameType(missing3)))
+      .exists(e => DataTypeUtils.sameType(e, missing3)))

     val source4 = StructType.fromDDL("c1 INT, c2 STRUCT<c3: INT, c4: STRUCT<c6: INT>>")
     val missing4 = StructType.fromDDL("c2 STRUCT<c4: STRUCT<c5: INT>>")
     assert(StructType.findMissingFields(source4, schema, resolver)
-      .exists(_.sameType(missing4)))
+      .exists(e => DataTypeUtils.sameType(e, missing4)))
   }

   test("find missing (nested) fields: array and map") {
@@ -192,7 +193,7 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper {
     val missing5 = StructType.fromDDL("c2 ARRAY<STRUCT<c3: INT, c4: LONG>>")
     assert(
       StructType.findMissingFields(source5, schemaWithArray, resolver)
-        .exists(_.sameType(missing5)))
+        .exists(e => DataTypeUtils.sameType(e, missing5)))

     val schemaWithMap1 = StructType.fromDDL(
       "c1 INT, c2 MAP<STRUCT<c3: INT, c4: LONG>, STRING>, c3 LONG")
@@ -200,7 +201,7 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper {
     val missing6 = StructType.fromDDL("c2 MAP<STRUCT<c3: INT, c4: LONG>, STRING>")
     assert(
       StructType.findMissingFields(source6, schemaWithMap1, resolver)
-        .exists(_.sameType(missing6)))
+        .exists(e => DataTypeUtils.sameType(e, missing6)))

     val schemaWithMap2 = StructType.fromDDL(
       "c1 INT, c2 MAP<STRING, STRUCT<c3: INT, c4: LONG>>, c3 STRING")
@@ -208,7 +209,7 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper {
     val missing7 = StructType.fromDDL("c2 MAP<STRING, STRUCT<c3: INT, c4: LONG>>")
     assert(
       StructType.findMissingFields(source7, schemaWithMap2, resolver)
-        .exists(_.sameType(missing7)))
+        .exists(e => DataTypeUtils.sameType(e, missing7)))

     // Unsupported: nested struct in array, map
     val source8 = StructType.fromDDL("c1 INT, c2 ARRAY<STRUCT<c3: INT>>")
@@ -232,22 +233,22 @@ class StructTypeSuite extends SparkFunSuite with SQLHelper {
     val source1 = StructType.fromDDL("c1 INT, C2 LONG")
     val missing1 = StructType.fromDDL("c2 STRUCT<c3: INT, C4: STRUCT<C5: INT, c6: INT>>")
     assert(StructType.findMissingFields(source1, schema, resolver)
-      .exists(_.sameType(missing1)))
+      .exists(e => DataTypeUtils.sameType(e, missing1)))

     val source2 = StructType.fromDDL("c2 LONG")
     val missing2 = StructType.fromDDL("c1 INT")
     assert(StructType.findMissingFields(source2, schema, resolver)
-      .exists(_.sameType(missing2)))
+      .exists(e => DataTypeUtils.sameType(e, missing2)))

     val source3 = StructType.fromDDL("c1 INT, c2 STRUCT<c3: INT, C4: STRUCT<c5: INT>>")
     val missing3 = StructType.fromDDL("c2 STRUCT<C4: STRUCT<C5: INT, c6: INT>>")
     assert(StructType.findMissingFields(source3, schema, resolver)
-      .exists(_.sameType(missing3)))
+      .exists(e => DataTypeUtils.sameType(e, missing3)))

     val source4 = StructType.fromDDL("c1 INT, c2 STRUCT<c3: INT, C4: STRUCT<C5: Int>>")
     val missing4 = StructType.fromDDL("c2 STRUCT<C4: STRUCT<c6: INT>>")
     assert(StructType.findMissingFields(source4, schema, resolver)
-      .exists(_.sameType(missing4)))
+      .exists(e => DataTypeUtils.sameType(e, missing4)))
   }

 }
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnVector.java
index 5272151acf2..e8ae6f290bc 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnVector.java
@@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
 import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
 import org.apache.spark.sql.types.ArrayType;
 import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.catalyst.types.DataTypeUtils;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.MapType;
 import org.apache.spark.sql.types.StructType;
@@ -64,7 +65,7 @@ final class ParquetColumnVector {
       boolean isTopLevel,
       Object defaultValue) {
     DataType sparkType = column.sparkType();
-    if (!sparkType.sameType(vector.dataType())) {
+    if (!DataTypeUtils.sameType(sparkType, vector.dataType())) {
       throw new IllegalArgumentException("Spark type: " + sparkType +
         " doesn't match the type: " + vector.dataType() + " in column vector");
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala
index ba6d351761e..cbff526592f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/binaryfile/BinaryFileFormat.scala
@@ -27,6 +27,7 @@ import org.apache.hadoop.mapreduce.Job
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriterFactory, PartitionedFile}
@@ -88,7 +89,7 @@ class BinaryFileFormat extends FileFormat with DataSourceRegister {
       filters: Seq[Filter],
       options: Map[String, String],
       hadoopConf: Configuration): PartitionedFile => Iterator[InternalRow] = {
-    require(dataSchema.sameType(schema),
+    require(DataTypeUtils.sameType(dataSchema, schema),
       s"""
          |Binary file data source expects dataSchema: $schema,
         |but got: $dataSchema.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 4595ea049ef..f5e8a36de5a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.execution.{CodegenSupport, ExplainUtils, RowIterator}
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types.{BooleanType, IntegralType, LongType}
@@ -109,7 +110,7 @@ trait HashJoin extends JoinCodegenSupport {
     require(leftKeys.length == rightKeys.length &&
       leftKeys.map(_.dataType)
         .zip(rightKeys.map(_.dataType))
-        .forall(types => types._1.sameType(types._2)),
+        .forall(types => DataTypeUtils.sameType(types._1, types._2)),
       "Join keys from two sides should have same length and types")
     buildSide match {
       case BuildLeft => (leftKeys, rightKeys)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
index f3531668c8e..79773a9d534 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.api.python.PythonSQLUtils
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.python.ApplyInPandasWithStatePythonRunner.{InType, OutType, OutTypeForState, STATE_METADATA_SCHEMA_FROM_PYTHON_WORKER}
 import org.apache.spark.sql.execution.python.ApplyInPandasWithStateWriter.STATE_METADATA_SCHEMA
@@ -154,7 +155,7 @@ class ApplyInPandasWithStatePythonRunner(
     // UDF returns a StructType column in ColumnarBatch, select the children here
     val structVector = batch.column(ordinal).asInstanceOf[ArrowColumnVector]
     val dataType = schema(ordinal).dataType.asInstanceOf[StructType]
-    assert(dataType.sameType(expectedType),
+    assert(DataTypeUtils.sameType(dataType, expectedType),
       s"Schema equality check failure! type from Arrow: $dataType, expected type: $expectedType")

     val outputVectors = dataType.indices.map(structVector.getChild)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
index 2445dcd519e..3ad1dc58cae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GenericInternalRow, JoinedRow, Literal, Predicate, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper._
@@ -182,7 +183,7 @@ case class StreamingSymmetricHashJoinExec(
     require(leftKeys.length == rightKeys.length &&
       leftKeys.map(_.dataType)
        .zip(rightKeys.map(_.dataType))
-        .forall(types => types._1.sameType(types._2)),
+        .forall(types => DataTypeUtils.sameType(types._1, types._2)),
       "Join keys from two sides should have same length and types")

   private val storeConf = new StateStoreConf(conf)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
index 2796b1cf154..a613fe04f5e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.TestingUDT.{IntervalUDT, NullData, NullUDT}
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GreaterThan, Literal}
 import org.apache.spark.sql.catalyst.expressions.IntegralLiteralTestUtils.{negativeInt, positiveInt}
 import org.apache.spark.sql.catalyst.plans.logical.Filter
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.execution.{FileSourceScanLike, SimpleMode}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.datasources.FilePartition
@@ -88,7 +89,7 @@ class FileBasedDataSourceSuite extends QueryTest
         df.write.format(format).option("header", "true").save(dir)
         val answerDf = spark.read.format(format).option("header", "true").load(dir)

-        assert(df.schema.sameType(answerDf.schema))
+        assert(DataTypeUtils.sameType(df.schema, answerDf.schema))
         checkAnswer(df, answerDf)
       }
     }
@@ -104,7 +105,7 @@ class FileBasedDataSourceSuite extends QueryTest
         emptyDf.write.format(format).save(path)
         val df = spark.read.format(format).load(path)

-        assert(df.schema.sameType(emptyDf.schema))
+        assert(DataTypeUtils.sameType(df.schema, emptyDf.schema))
         checkAnswer(df, emptyDf)
       }
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala
index 36905027cb0..33c6e4f5be6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/GroupBasedDeleteFromTableSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.connector

 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.expressions.DynamicPruningExpression
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.execution.InSubqueryExec
 import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.internal.SQLConf
@@ -141,14 +142,15 @@ class GroupBasedDeleteFromTableSuite extends DeleteFromTableSuiteBase {
     val primaryScan = collect(executedPlan) {
       case s: BatchScanExec => s
     }.head
-    assert(primaryScan.schema.sameType(StructType.fromDDL(primaryScanSchema)))
+    assert(DataTypeUtils.sameType(primaryScan.schema, StructType.fromDDL(primaryScanSchema)))

     primaryScan.runtimeFilters match {
       case Seq(DynamicPruningExpression(child: InSubqueryExec)) =>
         val groupFilterScan = collect(child.plan) {
           case s: BatchScanExec => s
         }.head
-        assert(groupFilterScan.schema.sameType(StructType.fromDDL(groupFilterScanSchema)))
+        assert(DataTypeUtils.sameType(groupFilterScan.schema,
+          StructType.fromDDL(groupFilterScanSchema)))

       case _ =>
         fail("could not find group filter scan")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
index bd9c79e5b96..bf496d6db21 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.SchemaPruningTest
 import org.apache.spark.sql.catalyst.expressions.Concat
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.plans.logical.Expand
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.execution.FileSourceScanExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.functions._
@@ -851,7 +852,7 @@ abstract class SchemaPruningSuite

   protected val schemaEquality = new Equality[StructType] {
     override def areEqual(a: StructType, b: Any): Boolean = b match {
-      case otherType: StructType => a.sameType(otherType)
+      case otherType: StructType => DataTypeUtils.sameType(a, otherType)
       case _ => false
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
index 43626237b13..183c4f71df6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.SparkConf
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils.localDateTimeToMicros
 import org.apache.spark.sql.execution.datasources._
@@ -1101,7 +1102,7 @@ abstract class ParquetPartitionDiscoverySuite
       val input =
         spark.read.parquet(path.getAbsolutePath).select("id",
           "date_month", "date_hour", "date_t_hour", "data")
-      assert(data.schema.sameType(input.schema))
+      assert(DataTypeUtils.sameType(data.schema, input.schema))
       checkAnswer(input, data)
     }
   }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
index 9af34ba3b20..90e8f9b9d0e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
@@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils}
 import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.execution.datasources.{PartitioningUtils, SourceOptions}
@@ -815,7 +816,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
       val partColumnNames = getPartitionColumnsFromTableProperties(table)
       val reorderedSchema = reorderSchema(schema = schemaFromTableProps, partColumnNames)

-      if (DataType.equalsIgnoreCaseAndNullability(reorderedSchema, table.schema) ||
+      if (DataTypeUtils.equalsIgnoreCaseAndNullability(reorderedSchema, table.schema) ||
           options.respectSparkSchema) {
         hiveTable.copy(
           schema = reorderedSchema,
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 12b570e8186..4d13c02c503 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.{AnalysisException, SparkSession}
 import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier}
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetOptions}
 import org.apache.spark.sql.internal.SQLConf
@@ -89,7 +90,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
         // we will use the cached relation.
         val useCached =
           relation.location.rootPaths.toSet == pathsInMetastore.toSet &&
-            logical.schema.sameType(schemaInMetastore) &&
+            DataTypeUtils.sameType(logical.schema, schemaInMetastore) &&
             // We don't support hive bucketed tables. This function `getCached` is only used for
             // converting supported Hive tables to data source tables.
             relation.bucketSpec.isEmpty &&
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
index e76ef4725f7..d1b94815c5d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
@@ -27,6 +27,7 @@ import org.apache.logging.log4j.Level
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType}
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME
 import org.apache.spark.sql.execution.command.CreateTableCommand
 import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
@@ -776,7 +777,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
         })

       // Make sure partition columns are correctly stored in metastore.
       assert(
-        expectedPartitionColumns.sameType(actualPartitionColumns),
+        DataTypeUtils.sameType(expectedPartitionColumns, actualPartitionColumns),
         s"Partitions columns stored in metastore $actualPartitionColumns is not the " +
           s"partition columns defined by the saveAsTable operation $expectedPartitionColumns.")
@@ -818,7 +819,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
         })

       // Make sure bucketBy columns are correctly stored in metastore.
       assert(
-        expectedBucketByColumns.sameType(actualBucketByColumns),
+        DataTypeUtils.sameType(expectedBucketByColumns, actualBucketByColumns),
         s"Partitions columns stored in metastore $actualBucketByColumns is not the " +
           s"partition columns defined by the saveAsTable operation $expectedBucketByColumns.")
@@ -829,7 +830,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
         })

       // Make sure sortBy columns are correctly stored in metastore.
       assert(
-        expectedSortByColumns.sameType(actualSortByColumns),
+        DataTypeUtils.sameType(expectedSortByColumns, actualSortByColumns),
         s"Partitions columns stored in metastore $actualSortByColumns is not the " +
           s"partition columns defined by the saveAsTable operation $expectedSortByColumns.")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala
index 8b7f7ade560..c81c4649107 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.hive.execution
 import org.apache.spark.sql.{AnalysisException, Row}
 import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType, HiveTableRelation}
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME
 import org.apache.spark.sql.execution.SQLViewSuite
 import org.apache.spark.sql.hive.{HiveExternalCatalog, HiveUtils}
@@ -143,7 +144,7 @@ class HiveSQLViewSuite extends SQLViewSuite with TestHiveSingleton {
         // Check the output rows.
         checkAnswer(df, Row(1, 2))
         // Check the output schema.
-        assert(df.schema.sameType(view.schema))
+        assert(DataTypeUtils.sameType(df.schema, view.schema))
       }
     }
   }
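As a quick illustration of the call shape this patch moves to, here is a minimal sketch (not part of the commit). The only facts taken from the patch are that sameType and equalsIgnoreCaseAndNullability now live on the DataTypeUtils object in org.apache.spark.sql.catalyst.types and take both types as arguments; everything else in the snippet (object name, schemas) is illustrative. DataTypeUtils is a Spark-internal Catalyst utility, so the sketch assumes a build containing this commit and, if the object is package-private, code compiled under Spark's own sql packages.

// Sketch only: assumes a Spark build that already contains this commit and that
// DataTypeUtils is reachable from the compiling code (it is Spark-internal).
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.types.{IntegerType, StructType}

object SameTypeMigrationSketch {
  def main(args: Array[String]): Unit = {
    // Two schemas with the same shape but different nullability and field-name case.
    val fromDdl = StructType.fromDDL("c1 INT, c2 STRUCT<c3: INT>")  // fields nullable by default
    val explicit = new StructType()
      .add("C1", IntegerType, nullable = false)
      .add("C2", new StructType().add("c3", IntegerType, nullable = false), nullable = false)

    // Before this commit, call sites wrote: fromDdl.sameType(explicit)
    // After it, internal callers go through the utility object instead.
    // sameType ignores nullability; whether it also ignores field-name case
    // follows the spark.sql.caseSensitive setting.
    println(DataTypeUtils.sameType(fromDdl, explicit))

    // equalsIgnoreCaseAndNullability ignores both field-name case and nullability,
    // so it returns true for these two schemas regardless of that setting.
    println(DataTypeUtils.equalsIgnoreCaseAndNullability(fromDdl, explicit))
  }
}

The same pattern explains every hunk above: a method call on the left-hand DataType becomes a two-argument call on DataTypeUtils, with no behavior change intended.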