This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 6161bf44f40  [SPARK-44353][CONNECT][SQL] Remove StructType.toAttributes
6161bf44f40 is described below

commit 6161bf44f40f8146ea4c115c788fd4eaeb128769
Author: Herman van Hovell <her...@databricks.com>
AuthorDate: Wed Jul 12 02:27:16 2023 -0400

    [SPARK-44353][CONNECT][SQL] Remove StructType.toAttributes

    ### What changes were proposed in this pull request?
    This PR removes StructType.toAttributes. It is being replaced by DataTypeUtils.toAttributes(..).

    ### Why are the changes needed?
    We want to move the DataType hierarchy into sql/api, which requires removing any Catalyst-specific APIs.

    ### Does this PR introduce _any_ user-facing change?
    No.

    ### How was this patch tested?
    Existing tests.

    Closes #41925 from hvanhovell/SPARK-44353.

    Authored-by: Herman van Hovell <her...@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../spark/sql/connect/client/SparkResult.scala | 4 +-
 .../sql/connect/planner/SparkConnectPlanner.scala | 18 +++---
 .../connect/planner/SparkConnectPlannerSuite.scala | 19 +++++-
 .../connect/planner/SparkConnectProtoSuite.scala | 12 ++--
 .../spark/sql/kafka010/KafkaBatchWrite.scala | 5 +-
 .../spark/sql/kafka010/KafkaStreamingWrite.scala | 5 +-
 .../spark/sql/catalyst/analysis/Analyzer.scala | 7 ++-
 .../sql/catalyst/analysis/AssignmentUtils.scala | 3 +-
 .../catalyst/analysis/ResolveInlineTables.scala | 2 +-
 .../catalyst/analysis/TableOutputResolver.scala | 5 +-
 .../sql/catalyst/analysis/v2ResolutionPlans.scala | 3 +-
 .../sql/catalyst/encoders/ExpressionEncoder.scala | 3 +-
 .../catalyst/optimizer/NestedColumnAliasing.scala | 4 +-
 .../spark/sql/catalyst/optimizer/Optimizer.scala | 3 +-
 .../spark/sql/catalyst/parser/AstBuilder.scala | 3 +-
 .../sql/catalyst/plans/logical/LocalRelation.scala | 7 ++-
 .../plans/logical/basicLogicalOperators.scala | 3 +-
 .../spark/sql/catalyst/plans/logical/object.scala | 7 ++-
 .../spark/sql/catalyst/types/DataTypeUtils.scala | 7 +++
 .../spark/sql/catalyst/util/GeneratedColumn.scala | 4 +-
 .../sql/connector/catalog/CatalogV2Implicits.scala | 3 +-
 .../datasources/v2/DataSourceV2Implicits.scala | 3 +-
 .../datasources/v2/DataSourceV2Relation.scala | 3 +-
 .../org/apache/spark/sql/types/StructType.scala | 6 +-
 .../scala/org/apache/spark/sql/HashBenchmark.scala | 3 +-
 .../spark/sql/UnsafeProjectionBenchmark.scala | 9 +--
 .../sql/catalyst/analysis/AnalysisSuite.scala | 13 ++--
 .../CreateTablePartitioningValidationSuite.scala | 3 +-
 .../catalyst/analysis/V2WriteAnalysisSuite.scala | 72 ++++++++++++----------
 .../catalyst/encoders/EncoderResolutionSuite.scala | 5 +-
 .../catalyst/encoders/ExpressionEncoderSuite.scala | 3 +-
 .../catalyst/expressions/SelectedFieldSuite.scala | 4 +-
 .../main/scala/org/apache/spark/sql/Dataset.scala | 5 +-
 .../spark/sql/RelationalGroupedDataset.scala | 7 ++-
 .../scala/org/apache/spark/sql/SparkSession.scala | 11 ++--
 .../execution/aggregate/HashAggregateExec.scala | 5 +-
 .../spark/sql/execution/aggregate/udaf.scala | 5 +-
 .../sql/execution/arrow/ArrowConverters.scala | 3 +-
 .../spark/sql/execution/command/SetCommand.scala | 5 +-
 .../apache/spark/sql/execution/command/ddl.scala | 5 +-
 .../spark/sql/execution/command/functions.scala | 3 +-
 .../sql/execution/datasources/FileFormat.scala | 3 +-
 .../execution/datasources/LogicalRelation.scala | 5 +-
 .../datasources/SaveIntoDataSourceCommand.scala | 3 +-
 .../sql/execution/datasources/SchemaPruning.scala | 4 +-
 .../execution/datasources/orc/OrcFileFormat.scala | 3 +-
 .../datasources/parquet/ParquetFileFormat.scala | 3 +-
 .../spark/sql/execution/datasources/rules.scala | 3 +-
 .../sql/execution/datasources/v2/FileScan.scala | 7 ++-
 .../sql/execution/datasources/v2/FileWrite.scala | 3 +-
 .../v2/PartitionReaderWithPartitionValues.scala | 3 +-
 .../execution/datasources/v2/PushDownUtils.scala | 3 +-
 .../datasources/v2/V2ScanRelationPushDown.scala | 3 +-
 .../execution/streaming/StreamingRelation.scala | 7 ++-
 .../spark/sql/execution/streaming/memory.scala | 3 +-
 .../streaming/sources/ConsoleStreamingWrite.scala | 3 +-
 .../streaming/sources/ForeachWriterTable.scala | 3 +-
 .../state/FlatMapGroupsWithStateExecHelper.scala | 5 +-
 .../state/SymmetricHashJoinStateManager.scala | 3 +-
 .../apache/spark/sql/internal/CatalogImpl.scala | 3 +-
 .../spark/sql/streaming/DataStreamReader.scala | 3 +-
 .../spark/sql/streaming/DataStreamWriter.scala | 4 +-
 .../sql/connector/TableCapabilityCheckSuite.scala | 5 +-
 .../connector/V2CommandsCaseSensitivitySuite.scala | 3 +-
 .../spark/sql/execution/GroupedIteratorSuite.scala | 7 ++-
 .../spark/sql/execution/debug/DebuggingSuite.scala | 3 +-
 .../sql/execution/streaming/MemorySinkSuite.scala | 3 +-
 .../streaming/MergingSessionsIteratorSuite.scala | 3 +-
 ...ngSortWithSessionWindowStateIteratorSuite.scala | 3 +-
 .../streaming/UpdatingSessionsIteratorSuite.scala | 3 +-
 .../StreamingAggregationStateManagerSuite.scala | 3 +-
 .../StreamingSessionWindowStateManagerSuite.scala | 3 +-
 .../state/SymmetricHashJoinStateManagerSuite.scala | 3 +-
 .../execution/ui/SQLAppStatusListenerSuite.scala | 5 +-
 .../spark/sql/sources/BucketedReadSuite.scala | 3 +-
 .../apache/spark/sql/streaming/StreamSuite.scala | 3 +-
 .../spark/sql/streaming/StreamingQuerySuite.scala | 5 +-
 .../spark/sql/sources/SimpleTextRelation.scala | 6 +-
 78 files changed, 273 insertions(+), 165 deletions(-)

diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala index a6ed31c1869..d33f405ee94 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ProductEncoder, UnboundRowEncoder} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Deserializer +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.connect.client.util.{AutoCloseables, Cleanable} import org.apache.spark.sql.connect.common.DataTypeProtoConverter import org.apache.spark.sql.types.{DataType, StructType} @@ -95,7 +96,8 @@ private[sql] class SparkResult[T]( } // TODO: create encoders that directly operate on arrow vectors.
if (boundEncoder == null) { - boundEncoder = createEncoder(structType).resolveAndBind(structType.toAttributes) + boundEncoder = createEncoder(structType) + .resolveAndBind(DataTypeUtils.toAttributes(structType)) } while (reader.loadNextBatch()) { val rowCount = root.getRowCount diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 5fd5f7d4c77..8b1e6779a63 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -51,6 +51,7 @@ import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, L import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.{AppendColumns, CoGroup, CollectMetrics, CommandResult, Deduplicate, DeduplicateWithinWatermark, DeserializeToObject, Except, FlatMapGroupsWithState, Intersect, JoinWith, LocalRelation, LogicalGroupState, LogicalPlan, MapGroups, MapPartitions, Project, Sample, SerializeFromObject, Sort, SubqueryAlias, TypedFilter, Union, Unpivot, UnresolvedHint} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils} import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ForeachWriterPacket, InvalidPlanInput, LiteralValueProtoConverter, StorageLevelProtoConverter, UdfPacket} import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE @@ -508,13 +509,13 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { case PythonEvalType.SQL_MAP_PANDAS_ITER_UDF => logical.MapInPandas( pythonUdf, - pythonUdf.dataType.asInstanceOf[StructType].toAttributes, + DataTypeUtils.toAttributes(pythonUdf.dataType.asInstanceOf[StructType]), baseRel, isBarrier) case PythonEvalType.SQL_MAP_ARROW_ITER_UDF => logical.PythonMapInArrow( pythonUdf, - pythonUdf.dataType.asInstanceOf[StructType].toAttributes, + DataTypeUtils.toAttributes(pythonUdf.dataType.asInstanceOf[StructType]), baseRel, isBarrier) case _ => @@ -638,7 +639,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { ds.groupingAttributes, ds.dataAttributes, udf.inputDeserializer(ds.groupingAttributes), - LocalRelation(initialDs.vEncoder.schema.toAttributes), // empty data set + LocalRelation(initialDs.vEncoder.schema), // empty data set ds.analyzed) } SerializeFromObject(udf.outputNamedExpression, flatMapGroupsWithState) @@ -1106,7 +1107,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { if (structType == null) { throw InvalidPlanInput(s"Input data for LocalRelation does not produce a schema.") } - val attributes = structType.toAttributes + val attributes = DataTypeUtils.toAttributes(structType) val proj = UnsafeProjection.create(attributes, attributes) val data = rows.map(proj) @@ -1133,22 +1134,23 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging { val project = Dataset .ofRows( session, - logicalPlan = - logical.LocalRelation(normalize(structType).asInstanceOf[StructType].toAttributes)) + logicalPlan = logical.LocalRelation(normalize(structType).asInstanceOf[StructType])) .toDF(normalized.names: _*) .to(normalized) .logicalPlan 
.asInstanceOf[Project] val proj = UnsafeProjection.create(project.projectList, project.child.output) - logical.LocalRelation(schema.toAttributes, data.map(proj).map(_.copy()).toSeq) + logical.LocalRelation( + DataTypeUtils.toAttributes(schema), + data.map(proj).map(_.copy()).toSeq) } } else { if (schema == null) { throw InvalidPlanInput( s"Schema for LocalRelation is required when the input data is not provided.") } - LocalRelation(schema.toAttributes, data = Seq.empty) + LocalRelation(schema) } } diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala index 14fdc8c0073..a10540676b0 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.expressions.{AttributeReference, UnsafeProjection} import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.connect.common.InvalidPlanInput import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto import org.apache.spark.sql.connect.service.SessionHolder @@ -71,6 +72,20 @@ trait SparkConnectPlanTest extends SharedSparkSession { .build()) .build() + /** + * Creates a local relation for testing purposes. The local relation is mapped to its + * equivalent in Catalyst and can be easily used for planner testing. + * + * @param schema + * the schema of LocalRelation + * @param data + * the data of LocalRelation + * @return + */ + def createLocalRelationProto(schema: StructType, data: Seq[InternalRow]): proto.Relation = { + createLocalRelationProto(DataTypeUtils.toAttributes(schema), data) + } + /** * Creates a local relation for testing purposes. The local relation is mapped to its * equivalent in Catalyst and can be easily used for planner testing.
@@ -456,7 +471,7 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { proj(row).copy() } - val localRelation = createLocalRelationProto(schema.toAttributes, inputRows) + val localRelation = createLocalRelationProto(schema, inputRows) val df = Dataset.ofRows(spark, transform(localRelation)) val array = df.collect() assertResult(10)(array.length) @@ -599,7 +614,7 @@ class SparkConnectPlannerSuite extends SparkFunSuite with SparkConnectPlanTest { proj(row).copy() } - val localRelation = createLocalRelationProto(schema.toAttributes, inputRows) + val localRelation = createLocalRelationProto(schema, inputRows) val project = proto.Project diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala index 2bb998a29b9..82941d8d72e 100644 --- a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectProtoSuite.scala @@ -710,7 +710,7 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { proj(row).copy() } - val localRelationV2 = createLocalRelationProto(schema.toAttributes, inputRows) + val localRelationV2 = createLocalRelationProto(schema, inputRows) val cmd = localRelationV2.writeV2( tableName = Some("testcat.table_name"), @@ -740,7 +740,7 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { proj(row).copy() } - val localRelationV2 = createLocalRelationProto(schema.toAttributes, inputRows) + val localRelationV2 = createLocalRelationProto(schema, inputRows) val cmd = localRelationV2.writeV2( tableName = Some("testcat.table_name"), @@ -778,7 +778,7 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { proj(row).copy() } - val localRelationV2 = createLocalRelationProto(schema.toAttributes, inputRows) + val localRelationV2 = createLocalRelationProto(schema, inputRows) spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo") @@ -817,8 +817,8 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { proj(row).copy() } - val localRelation1V2 = createLocalRelationProto(schema.toAttributes, inputRows1) - val localRelation2V2 = createLocalRelationProto(schema.toAttributes, inputRows2) + val localRelation1V2 = createLocalRelationProto(schema, inputRows1) + val localRelation2V2 = createLocalRelationProto(schema, inputRows2) spark.sql( "CREATE TABLE testcat.table_name (id bigint, data string) USING foo PARTITIONED BY (id)") @@ -865,7 +865,7 @@ class SparkConnectProtoSuite extends PlanTest with SparkConnectPlanTest { proj(row).copy() } - val localRelationV2 = createLocalRelationProto(schema.toAttributes, inputRows) + val localRelationV2 = createLocalRelationProto(schema, inputRows) spark.sql( "CREATE TABLE testcat.table_name (id bigint, data string) USING foo PARTITIONED BY (id)") diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchWrite.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchWrite.scala index 56c0fdd7c35..002da3c5132 100644 --- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchWrite.scala +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchWrite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.kafka010 import 
java.{util => ju} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, PhysicalWriteInfo, WriterCommitMessage} import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery import org.apache.spark.sql.types.StructType @@ -38,7 +39,7 @@ private[kafka010] class KafkaBatchWrite( schema: StructType) extends BatchWrite { - validateQuery(schema.toAttributes, producerParams, topic) + validateQuery(DataTypeUtils.toAttributes(schema), producerParams, topic) override def createBatchWriterFactory(info: PhysicalWriteInfo): KafkaBatchWriterFactory = KafkaBatchWriterFactory(topic, producerParams, schema) @@ -62,6 +63,6 @@ private case class KafkaBatchWriterFactory( extends DataWriterFactory { override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = { - new KafkaDataWriter(topic, producerParams, schema.toAttributes) + new KafkaDataWriter(topic, producerParams, DataTypeUtils.toAttributes(schema)) } } diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWrite.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWrite.scala index db719966267..1fdf1b9293d 100644 --- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWrite.scala +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWrite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.kafka010 import java.{util => ju} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.connector.write.{DataWriter, PhysicalWriteInfo, WriterCommitMessage} import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite} import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery @@ -39,7 +40,7 @@ private[kafka010] class KafkaStreamingWrite( schema: StructType) extends StreamingWrite { - validateQuery(schema.toAttributes, producerParams, topic) + validateQuery(DataTypeUtils.toAttributes(schema), producerParams, topic) override def createStreamingWriterFactory( info: PhysicalWriteInfo): KafkaStreamWriterFactory = @@ -69,6 +70,6 @@ private case class KafkaStreamWriterFactory( partitionId: Int, taskId: Long, epochId: Long): DataWriter[InternalRow] = { - new KafkaDataWriter(topic, producerParams, schema.toAttributes) + new KafkaDataWriter(topic, producerParams, DataTypeUtils.toAttributes(schema)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 7c91c2ee451..55433ea04b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.catalyst.trees.AlwaysProcess import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin import org.apache.spark.sql.catalyst.trees.TreePattern._ +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.{toPrettySQL, AUTO_GENERATED_ALIAS, CharVarcharUtils} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._ import org.apache.spark.sql.connector.catalog.{View => _, _} @@ -2859,7 +2860,7 
@@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor private[analysis] def makeGeneratorOutput( generator: Generator, names: Seq[String]): Seq[Attribute] = { - val elementAttrs = generator.elementSchema.toAttributes + val elementAttrs = DataTypeUtils.toAttributes(generator.elementSchema) if (names.length == elementAttrs.length) { names.zip(elementAttrs).map { @@ -3240,11 +3241,11 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor val dataType = udf.children(i).dataType encOpt.map { enc => val attrs = if (enc.isSerializedAsStructForTopLevel) { - dataType.asInstanceOf[StructType].toAttributes + DataTypeUtils.toAttributes(dataType.asInstanceOf[StructType]) } else { // the field name doesn't matter here, so we use // a simple literal to avoid any overhead - new StructType().add("input", dataType).toAttributes + DataTypeUtils.toAttribute(StructField("input", dataType)) :: Nil } enc.resolveAndBind(attrs) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala index fa953c90532..c9ee68a0dc7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AssignmentUtils.scala @@ -22,6 +22,7 @@ import scala.collection.mutable import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.expressions.{Attribute, CreateNamedStruct, Expression, GetStructField, Literal} import org.apache.spark.sql.catalyst.plans.logical.Assignment +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getDefaultValueExprOrNullLit import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ @@ -177,7 +178,7 @@ object AssignmentUtils extends SQLConfHelper with CastSupport { col.dataType match { case structType: StructType => - val fieldAttrs = structType.toAttributes + val fieldAttrs = DataTypeUtils.toAttributes(structType) val fieldExprs = structType.fields.zipWithIndex.map { case (field, ordinal) => GetStructField(colExpr, ordinal, Some(field.name)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala index 2be1c1b7b08..cf706171cd9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala @@ -101,7 +101,7 @@ object ResolveInlineTables extends Rule[LogicalPlan] with CastSupport with Alias } StructField(name, tpe, nullable = column.exists(_.nullable)) } - val attributes = StructType(fields).toAttributes + val attributes = DataTypeUtils.toAttributes(StructType(fields)) assert(fields.size == table.names.size) val newRows: Seq[InternalRow] = table.rows.map { row => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala index 6718020685b..9c94437dbc7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.getDefaultValueExprOrNullLit import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ @@ -313,9 +314,9 @@ object TableOutputResolver { Alias(GetStructField(nullCheckedInput, i, Some(f.name)), f.name)() } val resolved = if (byName) { - reorderColumnsByName(fields, expectedType.toAttributes, conf, addError, colPath) + reorderColumnsByName(fields, toAttributes(expectedType), conf, addError, colPath) } else { - resolveColumnsByPosition(fields, expectedType.toAttributes, conf, addError, colPath) + resolveColumnsByPosition(fields, toAttributes(expectedType), conf, addError, colPath) } if (resolved.length == expectedType.length) { val struct = CreateStruct(resolved) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala index cbc08f760fb..04d6337376c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, LeafExpression, Unevaluable} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UNRESOLVED_FUNC} +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.catalog.{CatalogPlugin, FunctionCatalog, Identifier, Table, TableCatalog} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ @@ -162,7 +163,7 @@ object ResolvedTable { identifier: Identifier, table: Table): ResolvedTable = { val schema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(table.columns.asSchema) - ResolvedTable(catalog, identifier, table, schema.toAttributes) + ResolvedTable(catalog, identifier, table, toAttributes(schema)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index cfcc1959a3d..83a018bafe7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, InitializeJavaBean, Invoke, NewInstance} import org.apache.spark.sql.catalyst.optimizer.{ReassignLambdaVariableID, SimplifyCasts} import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LeafNode, LocalRelation} +import org.apache.spark.sql.catalyst.types.DataTypeUtils import 
org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types.{ObjectType, StringType, StructField, StructType} import org.apache.spark.unsafe.types.UTF8String @@ -333,7 +334,7 @@ case class ExpressionEncoder[T]( * this method to do resolution and binding outside of query framework. */ def resolveAndBind( - attrs: Seq[Attribute] = schema.toAttributes, + attrs: Seq[Attribute] = DataTypeUtils.toAttributes(schema), analyzer: Analyzer = SimpleAnalyzer): ExpressionEncoder[T] = { val dummyPlan = CatalystSerde.deserialize(LocalRelation(attrs))(this) val analyzedPlan = analyzer.execute(dummyPlan) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala index 579afa0439a..5d4fcf772b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala @@ -17,12 +17,12 @@ package org.apache.spark.sql.catalyst.optimizer -import scala.collection import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -433,7 +433,7 @@ object GeneratorNestedColumnAliasing { // As we change the child of the generator, its output data type must be updated. val updatedGeneratorOutput = rewrittenG.generatorOutput - .zip(rewrittenG.generator.elementSchema.toAttributes) + .zip(toAttributes(rewrittenG.generator.elementSchema)) .map { case (oldAttr, newAttr) => newAttr.withExprId(oldAttr.exprId).withName(oldAttr.name) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 9e9e606d438..fd2ea96a296 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.AlwaysProcess import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf @@ -2422,7 +2423,7 @@ object GenerateOptimization extends Rule[LogicalPlan] { } // As we change the child of the generator, its output data type must be updated. 
val updatedGeneratorOutput = rewrittenG.generatorOutput - .zip(rewrittenG.generator.elementSchema.toAttributes) + .zip(toAttributes(rewrittenG.generator.elementSchema)) .map { case (oldAttr, newAttr) => newAttr.withExprId(oldAttr.exprId).withName(oldAttr.name) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 0eeef48b071..1fdfbb97abc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -40,6 +40,7 @@ import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.CurrentOrigin +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils, GeneratedColumn, IntervalUtils, ResolveDefaultColumns} import org.apache.spark.sql.catalyst.util.DateTimeUtils.{convertSpecialDate, convertSpecialTimestamp, convertSpecialTimestampNTZ, getZoneId, stringToDate, stringToTimestamp, stringToTimestampWithoutTimeZone} import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsNamespaces, TableCatalog} @@ -785,7 +786,7 @@ class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging { // Create the attributes. val (attributes, schemaLess) = if (transformClause.colTypeList != null) { // Typed return columns. - (createSchema(transformClause.colTypeList).toAttributes, false) + (DataTypeUtils.toAttributes(createSchema(transformClause.colTypeList)), false) } else if (transformClause.identifierSeq != null) { // Untyped return columns. 
val attrs = visitIdentifierSeq(transformClause.identifierSeq).map { name => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index 305f13481aa..78c1087a1b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, TreePattern} +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.types.{StructField, StructType} import org.apache.spark.util.Utils @@ -30,7 +31,11 @@ object LocalRelation { def apply(output: Attribute*): LocalRelation = new LocalRelation(output) def apply(output1: StructField, output: StructField*): LocalRelation = { - new LocalRelation(StructType(output1 +: output).toAttributes) + apply(StructType(output1 +: output)) + } + + def apply(schema: StructType): LocalRelation = { + new LocalRelation(DataTypeUtils.toAttributes(schema)) } def fromExternalRows(output: Seq[Attribute], data: Seq[Row]): LocalRelation = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index c5ac0304841..f8ba042009b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partition import org.apache.spark.sql.catalyst.trees.TreeNodeTag import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf @@ -922,7 +923,7 @@ object Range { } def getOutputAttrs: Seq[Attribute] = { - StructType(Array(StructField("id", LongType, nullable = false))).toAttributes + toAttributes(StructType(Array(StructField("id", LongType, nullable = false)))) } private def typeCoercion: TypeCoercionBase = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 980295f5e0d..75dab55dccf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.catalyst.plans.{InnerLike, LeftAnti, LeftSemi, ReferenceAllColumns} import org.apache.spark.sql.catalyst.trees.TreePattern._ +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf 
import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} @@ -137,7 +138,7 @@ object MapPartitionsInR { packageNames, broadcastVars, encoder.schema, - schema.toAttributes, + toAttributes(schema), child) } else { val deserialized = CatalystSerde.deserialize(child)(encoder) @@ -481,7 +482,7 @@ object FlatMapGroupsWithState { groupingAttributes, dataAttributes, UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes), - LocalRelation(stateEncoder.schema.toAttributes), // empty data set + LocalRelation(stateEncoder.schema), // empty data set child ) CatalystSerde.serialize[U](mapped) @@ -594,7 +595,7 @@ object FlatMapGroupsInR { packageNames, broadcastVars, inputSchema, - schema.toAttributes, + toAttributes(schema), UnresolvedDeserializer(keyDeserializer, groupingAttributes), groupingAttributes, child) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala index 9689eca73d5..d0df8c3270e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala @@ -186,6 +186,13 @@ object DataTypeUtils { def toAttribute(field: StructField): AttributeReference = AttributeReference(field.name, field.dataType, field.nullable, field.metadata)() + /** + * Convert a [[StructType]] into a Seq of [[AttributeReference]]. + */ + def toAttributes(schema: StructType): Seq[AttributeReference] = { + schema.map(toAttribute) + } + /** * Convert a literal to a DecimalType. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GeneratedColumn.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GeneratedColumn.scala index 5dd278e3fea..28ddc16cf6b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GeneratedColumn.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GeneratedColumn.scala @@ -116,8 +116,8 @@ object GeneratedColumn { val allowedBaseColumns = schema .filterNot(_.name == fieldName) // Can't reference itself .filterNot(isGeneratedColumn) // Can't reference other generated columns - val relation = new LocalRelation(CharVarcharUtils.replaceCharVarcharWithStringInSchema( - StructType(allowedBaseColumns)).toAttributes) + val relation = LocalRelation( + CharVarcharUtils.replaceCharVarcharWithStringInSchema(StructType(allowedBaseColumns))) val plan = try { val analyzer: Analyzer = GeneratedColumnAnalyzer val analyzed = analyzer.execute(Project(Seq(Alias(parsed, fieldName)()), relation)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala index 12858887bb5..4ca926abad3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.quoteIfNeeded import 
org.apache.spark.sql.connector.expressions.{BucketTransform, FieldReference, IdentityTransform, LogicalExpressions, Transform} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} @@ -187,7 +188,7 @@ private[sql] object CatalogV2Implicits { implicit class ColumnsHelper(columns: Array[Column]) { def asSchema: StructType = CatalogV2Util.v2ColumnsToStructType(columns) - def toAttributes: Seq[AttributeReference] = asSchema.toAttributes + def toAttributes: Seq[AttributeReference] = DataTypeUtils.toAttributes(asSchema) } def parseColumnPath(name: String): Seq[String] = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala index ad8320da1fd..bb55eb0f41f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala @@ -21,6 +21,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.catalyst.analysis.{PartitionSpec, ResolvedPartitionSpec, UnresolvedPartitionSpec} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, MetadataAttribute} +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.connector.catalog.{MetadataColumn, SupportsAtomicPartitionManagement, SupportsDeleteV2, SupportsPartitionManagement, SupportsRead, SupportsWrite, Table, TableCapability, TruncatableTable} import org.apache.spark.sql.connector.write.RowLevelOperationTable import org.apache.spark.sql.errors.QueryCompilationErrors @@ -109,7 +110,7 @@ object DataSourceV2Implicits { StructType(fields) } - def toAttributes: Seq[AttributeReference] = asStruct.toAttributes + def toAttributes: Seq[AttributeReference] = DataTypeUtils.toAttributes(asStruct) } implicit class OptionsHelper(options: Map[String, String]) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index c868f20fed4..92638a15287 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Expression, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, ExposesMetadataColumns, Histogram, HistogramBin, LeafNode, LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.{truncatedString, CharVarcharUtils} import org.apache.spark.sql.connector.catalog.{CatalogPlugin, FunctionCatalog, Identifier, SupportsMetadataColumns, Table, TableCapability} import org.apache.spark.sql.connector.read.{Scan, Statistics => V2Statistics, SupportsReportStatistics} @@ -194,7 +195,7 @@ object DataSourceV2Relation { // The v2 source may return schema containing char/varchar type. We replace char/varchar // with "annotated" string type here as the query engine doesn't support char/varchar yet. 
val schema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(table.columns.asSchema) - DataSourceV2Relation(table, schema.toAttributes, catalog, identifier, options) + DataSourceV2Relation(table, toAttributes(schema), catalog, identifier, options) } def create( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index d92ba390f25..df903c57887 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -26,10 +26,9 @@ import org.json4s.JsonDSL._ import org.apache.spark.annotation.Stable import org.apache.spark.sql.SqlApiConf import org.apache.spark.sql.catalyst.analysis.Resolver -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.parser.{DataTypeParser, LegacyTypeStringParser} import org.apache.spark.sql.catalyst.trees.Origin -import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.{SparkStringUtils, StringConcat} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.util.collection.Utils @@ -379,9 +378,6 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru findFieldInStruct(this, fieldNames, Nil) } - protected[sql] def toAttributes: Seq[AttributeReference] = - map(field => DataTypeUtils.toAttribute(field)) - def treeString: String = treeString(Int.MaxValue) def treeString(maxDepth: Int): String = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala index 8e96faace52..13ab7e2a705 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/HashBenchmark.scala @@ -21,6 +21,7 @@ import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.types._ /** @@ -42,7 +43,7 @@ object HashBenchmark extends BenchmarkBase { runBenchmark(name) { val generator = RandomDataGenerator.forType(schema, nullable = false).get val toRow = RowEncoder(schema).createSerializer() - val attrs = schema.toAttributes + val attrs = DataTypeUtils.toAttributes(schema) val safeProjection = GenerateSafeProjection.generate(attrs, attrs) val rows = (1 to numRows).map(_ => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala index 07179a20cd0..b7704eb211f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/UnsafeProjectionBenchmark.scala @@ -21,6 +21,7 @@ import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.UnsafeProjection +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.types._ /** @@ -50,7 +51,7 @@ object 
UnsafeProjectionBenchmark extends BenchmarkBase { val benchmark = new Benchmark("unsafe projection", iters * numRows.toLong, output = output) val schema1 = new StructType().add("l", LongType, false) - val attrs1 = schema1.toAttributes + val attrs1 = DataTypeUtils.toAttributes(schema1) val rows1 = generateRows(schema1, numRows) val projection1 = UnsafeProjection.create(attrs1, attrs1) @@ -66,7 +67,7 @@ object UnsafeProjectionBenchmark extends BenchmarkBase { } val schema2 = new StructType().add("l", LongType, true) - val attrs2 = schema2.toAttributes + val attrs2 = DataTypeUtils.toAttributes(schema2) val rows2 = generateRows(schema2, numRows) val projection2 = UnsafeProjection.create(attrs2, attrs2) @@ -89,7 +90,7 @@ object UnsafeProjectionBenchmark extends BenchmarkBase { .add("long", LongType, false) .add("float", FloatType, false) .add("double", DoubleType, false) - val attrs3 = schema3.toAttributes + val attrs3 = DataTypeUtils.toAttributes(schema3) val rows3 = generateRows(schema3, numRows) val projection3 = UnsafeProjection.create(attrs3, attrs3) @@ -112,7 +113,7 @@ object UnsafeProjectionBenchmark extends BenchmarkBase { .add("long", LongType, true) .add("float", FloatType, true) .add("double", DoubleType, true) - val attrs4 = schema4.toAttributes + val attrs4 = DataTypeUtils.toAttributes(schema4) val rows4 = generateRows(schema4, numRows) val projection4 = UnsafeProjection.create(attrs4, attrs4) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 6a08600fb41..2bb3439da01 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser.parsePlan import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, UsingJoin} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning} +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.connector.catalog.InMemoryTable import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation @@ -64,7 +65,11 @@ class AnalysisSuite extends AnalysisTest with Matchers { val table = new InMemoryTable("t", schema, Array.empty, Map.empty[String, String].asJava) intercept[IllegalStateException] { DataSourceV2Relation( - table, schema.toAttributes, None, None, CaseInsensitiveStringMap.empty()).analyze + table, + DataTypeUtils.toAttributes(schema), + None, + None, + CaseInsensitiveStringMap.empty()).analyze } } } @@ -646,7 +651,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { Seq.empty, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, true) - val output = pythonUdf.dataType.asInstanceOf[StructType].toAttributes + val output = DataTypeUtils.toAttributes(pythonUdf.dataType.asInstanceOf[StructType]) val project = Project(Seq(UnresolvedAttribute("a")), testRelation) val flatMapGroupsInPandas = FlatMapGroupsInPandas( Seq(UnresolvedAttribute("a")), pythonUdf, output, project) @@ -663,7 +668,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { Seq.empty, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, true) - val output = pythonUdf.dataType.asInstanceOf[StructType].toAttributes + val output = 
DataTypeUtils.toAttributes(pythonUdf.dataType.asInstanceOf[StructType]) val project1 = Project(Seq(UnresolvedAttribute("a")), testRelation) val project2 = Project(Seq(UnresolvedAttribute("a")), testRelation2) val flatMapGroupsInPandas = FlatMapCoGroupsInPandas( @@ -686,7 +691,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { Seq.empty, PythonEvalType.SQL_MAP_PANDAS_ITER_UDF, true) - val output = pythonUdf.dataType.asInstanceOf[StructType].toAttributes + val output = DataTypeUtils.toAttributes(pythonUdf.dataType.asInstanceOf[StructType]) val project = Project(Seq(UnresolvedAttribute("a")), testRelation) val mapInPandas = MapInPandas( pythonUdf, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala index 4158dc9e273..882bed48f8f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala @@ -21,6 +21,7 @@ import java.util import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, LeafNode, OptionList, UnresolvedTableSpec} +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.connector.catalog.{InMemoryTableCatalog, Table, TableCapability, TableCatalog} import org.apache.spark.sql.connector.expressions.Expressions import org.apache.spark.sql.types.{DoubleType, LongType, StringType, StructType} @@ -142,7 +143,7 @@ private[sql] object CreateTablePartitioningValidationSuite { private[sql] case object TestRelation2 extends LeafNode with NamedRelation { override def name: String = "source_relation" override def output: Seq[AttributeReference] = - CreateTablePartitioningValidationSuite.schema.toAttributes + DataTypeUtils.toAttributes(CreateTablePartitioningValidationSuite.schema) } private[sql] case object TestTable2 extends Table { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala index 5b51730e759..0a97b130c80 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/V2WriteAnalysisSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cas import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy import org.apache.spark.sql.types._ @@ -109,12 +110,21 @@ case class TestRelation(output: Seq[AttributeReference]) extends LeafNode with N override def name: String = "table-name" } +object TestRelation { + def apply(schema: StructType): TestRelation = apply(DataTypeUtils.toAttributes(schema)) +} + case class TestRelationAcceptAnySchema(output: Seq[AttributeReference]) extends LeafNode with NamedRelation { override def name: String = "test-name" override def skipSchemaResolution: Boolean = true } +object 
TestRelationAcceptAnySchema { + def apply(schema: StructType): TestRelationAcceptAnySchema = + apply(DataTypeUtils.toAttributes(schema)) +} + abstract class V2ANSIWriteAnalysisSuiteBase extends V2WriteAnalysisSuiteBase { // For Ansi store assignment policy, expression `AnsiCast` is used instead of `Cast`. @@ -176,11 +186,11 @@ abstract class V2StrictWriteAnalysisSuiteBase extends V2WriteAnalysisSuiteBase { test("byName: multiple field errors are reported") { val xRequiredTable = TestRelation(StructType(Seq( StructField("x", FloatType, nullable = false), - StructField("y", DoubleType))).toAttributes) + StructField("y", DoubleType)))) val query = TestRelation(StructType(Seq( StructField("x", DoubleType), - StructField("b", FloatType))).toAttributes) + StructField("b", FloatType)))) val parsedPlan = byName(xRequiredTable, query) @@ -194,7 +204,7 @@ abstract class V2StrictWriteAnalysisSuiteBase extends V2WriteAnalysisSuiteBase { test("byPosition: fail canWrite check") { val widerTable = TestRelation(StructType(Seq( StructField("a", DoubleType), - StructField("b", DoubleType))).toAttributes) + StructField("b", DoubleType)))) val parsedPlan = byPosition(table, widerTable) @@ -207,11 +217,11 @@ abstract class V2StrictWriteAnalysisSuiteBase extends V2WriteAnalysisSuiteBase { test("byPosition: multiple field errors are reported") { val xRequiredTable = TestRelation(StructType(Seq( StructField("x", FloatType, nullable = false), - StructField("y", FloatType))).toAttributes) + StructField("y", FloatType)))) val query = TestRelation(StructType(Seq( StructField("x", DoubleType), - StructField("b", DoubleType))).toAttributes) + StructField("b", DoubleType)))) val parsedPlan = byPosition(xRequiredTable, query) @@ -239,8 +249,8 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { test("SPARK-33136: output resolved on complex types for V2 write commands") { def assertTypeCompatibility(name: String, fromType: DataType, toType: DataType): Unit = { - val table = TestRelation(StructType(Seq(StructField("a", toType))).toAttributes) - val query = TestRelation(StructType(Seq(StructField("a", fromType))).toAttributes) + val table = TestRelation(StructType(Seq(StructField("a", toType)))) + val query = TestRelation(StructType(Seq(StructField("a", fromType)))) val parsedPlan = byName(table, query) assertResolved(parsedPlan) checkAnalysis(parsedPlan, parsedPlan) @@ -302,14 +312,14 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { test("skipSchemaResolution should still require query to be resolved") { val table = TestRelationAcceptAnySchema(StructType(Seq( StructField("a", FloatType), - StructField("b", DoubleType))).toAttributes) + StructField("b", DoubleType)))) val query = UnresolvedRelation(Seq("t")) val parsedPlan = byName(table, query) assertNotResolved(parsedPlan) } test("byName: basic behavior") { - val query = TestRelation(table.schema.toAttributes) + val query = TestRelation(table.schema) val parsedPlan = byName(table, query) @@ -320,7 +330,7 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { test("byName: does not match by position") { val query = TestRelation(StructType(Seq( StructField("a", FloatType), - StructField("b", FloatType))).toAttributes) + StructField("b", FloatType)))) val parsedPlan = byName(table, query) @@ -333,7 +343,7 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { test("byName: case sensitive column resolution") { val query = TestRelation(StructType(Seq( StructField("X", FloatType), // doesn't match case! 
- StructField("y", FloatType))).toAttributes) + StructField("y", FloatType)))) val parsedPlan = byName(table, query) @@ -347,7 +357,7 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { test("byName: case insensitive column resolution") { val query = TestRelation(StructType(Seq( StructField("X", FloatType), // doesn't match case! - StructField("y", FloatType))).toAttributes) + StructField("y", FloatType)))) val X = query.output.head val y = query.output.last @@ -364,7 +374,7 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { // out of order val query = TestRelation(StructType(Seq( StructField("y", FloatType), - StructField("x", FloatType))).toAttributes) + StructField("x", FloatType)))) val y = query.output.head val x = query.output.last @@ -395,7 +405,7 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { test("byName: missing required columns cause failure and are identified by name") { // missing required field x val query = TestRelation(StructType(Seq( - StructField("y", FloatType, nullable = false))).toAttributes) + StructField("y", FloatType, nullable = false)))) val parsedPlan = byName(requiredTable, query) @@ -408,7 +418,7 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { test("byName: missing optional columns cause failure and are identified by name") { // missing optional field x val query = TestRelation(StructType(Seq( - StructField("y", FloatType))).toAttributes) + StructField("y", FloatType)))) val parsedPlan = byName(table, query) @@ -438,7 +448,7 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { val query = TestRelation(StructType(Seq( StructField("x", FloatType), StructField("y", FloatType), - StructField("z", FloatType))).toAttributes) + StructField("z", FloatType)))) val parsedPlan = byName(table, query) @@ -466,7 +476,7 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { test("byPosition: basic behavior") { val query = TestRelation(StructType(Seq( StructField("a", FloatType), - StructField("b", FloatType))).toAttributes) + StructField("b", FloatType)))) val a = query.output.head val b = query.output.last @@ -487,7 +497,7 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { // out of order val query = TestRelation(StructType(Seq( StructField("y", FloatType), - StructField("x", FloatType))).toAttributes) + StructField("x", FloatType)))) val y = query.output.head val x = query.output.last @@ -522,7 +532,7 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { test("byPosition: missing required columns cause failure") { // missing optional field x val query = TestRelation(StructType(Seq( - StructField("y", FloatType, nullable = false))).toAttributes) + StructField("y", FloatType, nullable = false)))) val parsedPlan = byPosition(requiredTable, query) @@ -540,7 +550,7 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { test("byPosition: missing optional columns cause failure") { // missing optional field x val query = TestRelation(StructType(Seq( - StructField("y", FloatType))).toAttributes) + StructField("y", FloatType)))) val parsedPlan = byPosition(table, query) @@ -558,7 +568,7 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { test("byPosition: insert safe cast") { val widerTable = TestRelation(StructType(Seq( StructField("a", DoubleType), - StructField("b", DoubleType))).toAttributes) + StructField("b", DoubleType)))) val x = table.output.head val y = table.output.last @@ -579,7 +589,7 @@ abstract class V2WriteAnalysisSuiteBase 
extends AnalysisTest { val query = TestRelation(StructType(Seq( StructField("a", FloatType), StructField("b", FloatType), - StructField("c", FloatType))).toAttributes) + StructField("c", FloatType)))) val parsedPlan = byName(table, query) @@ -597,10 +607,10 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { test("bypass output column resolution") { val table = TestRelationAcceptAnySchema(StructType(Seq( StructField("a", FloatType, nullable = false), - StructField("b", DoubleType))).toAttributes) + StructField("b", DoubleType)))) val query = TestRelation(StructType(Seq( - StructField("s", StringType))).toAttributes) + StructField("s", StringType)))) withClue("byName") { val parsedPlan = byName(table, query) @@ -619,13 +629,13 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { val tableWithStructCol = TestRelation( new StructType().add( "col", new StructType().add("a", IntegerType).add("b", IntegerType) - ).toAttributes + ) ) val query = TestRelation( new StructType().add( "col", new StructType().add("x", IntegerType).add("y", IntegerType) - ).toAttributes + ) ) withClue("byName") { @@ -1134,11 +1144,11 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { protected def testResolvedOverwriteByExpression(): Unit = { val table = TestRelation(StructType(Seq( StructField("x", DoubleType, nullable = false), - StructField("y", DoubleType))).toAttributes) + StructField("y", DoubleType)))) val query = TestRelation(StructType(Seq( StructField("a", DoubleType, nullable = false), - StructField("b", DoubleType))).toAttributes) + StructField("b", DoubleType)))) val a = query.output.head val b = query.output.last @@ -1162,11 +1172,11 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { protected def testNotResolvedOverwriteByExpression(): Unit = { val table = TestRelation(StructType(Seq( StructField("x", DoubleType, nullable = false), - StructField("y", DoubleType))).toAttributes) + StructField("y", DoubleType)))) val query = TestRelation(StructType(Seq( StructField("a", DoubleType, nullable = false), - StructField("b", DoubleType))).toAttributes) + StructField("b", DoubleType)))) // the write is resolved (checked above). this test plan is not because of the expression. 
val parsedPlan = OverwriteByExpression.byPosition(table, query, @@ -1181,7 +1191,7 @@ abstract class V2WriteAnalysisSuiteBase extends AnalysisTest { val tableAcceptAnySchema = TestRelationAcceptAnySchema(StructType(Seq( StructField("x", DoubleType, nullable = false), - StructField("y", DoubleType))).toAttributes) + StructField("y", DoubleType)))) val parsedPlan2 = OverwriteByExpression.byPosition(tableAcceptAnySchema, query, LessThanOrEqual(UnresolvedAttribute(Seq("a")), Literal(15.0d))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index c962f953696..f4106e65e7c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -305,7 +306,7 @@ class EncoderResolutionSuite extends PlanTest { val to = ExpressionEncoder[U] val catalystType = from.schema.head.dataType.simpleString test(s"cast from $catalystType to ${implicitly[TypeTag[U]].tpe} should succeed") { - to.resolveAndBind(from.schema.toAttributes) + to.resolveAndBind(toAttributes(from.schema)) } } @@ -314,7 +315,7 @@ class EncoderResolutionSuite extends PlanTest { val to = ExpressionEncoder[U] val catalystType = from.schema.head.dataType.simpleString test(s"cast from $catalystType to ${implicitly[TypeTag[U]].tpe} should fail") { - intercept[AnalysisException](to.resolveAndBind(from.schema.toAttributes)) + intercept[AnalysisException](to.resolveAndBind(toAttributes(from.schema))) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 79417c4ca1f..9d2051b01d6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.CodegenInterpretedPlanTest import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -670,7 +671,7 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes ClosureCleaner.clean((s: String) => encoder.getClass.getName) val row = encoder.createSerializer().apply(input) - val schema = encoder.schema.toAttributes + val schema = toAttributes(encoder.schema) val boundEncoder = encoder.resolveAndBind() val convertedBack = try boundEncoder.createDeserializer().apply(row) catch { case e: Exception => diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SelectedFieldSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SelectedFieldSuite.scala index e09ae776d1c..ddeac6cbb93 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SelectedFieldSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SelectedFieldSuite.scala @@ -58,7 +58,7 @@ class SelectedFieldSuite extends AnalysisTest { MapType(StringType, IntegerType, valueContainsNull = false)) :: Nil)) :: Nil) test("SelectedField should not match an attribute reference") { - val testRelation = LocalRelation(nestedComplex.toAttributes) + val testRelation = LocalRelation(nestedComplex) assertResult(None)(unapplySelect("col1", testRelation)) assertResult(None)(unapplySelect("col1 as foo", testRelation)) assertResult(None)(unapplySelect("col2", testRelation)) @@ -553,7 +553,7 @@ class SelectedFieldSuite extends AnalysisTest { } private def assertSelect(expr: String, expected: StructField, inputSchema: StructType): Unit = { - val relation = LocalRelation(inputSchema.toAttributes) + val relation = LocalRelation(inputSchema) unapplySelect(expr, relation) match { case Some(field) => assertResult(expected)(field)(expr) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 9c2a9dec1fa..76caaabd942 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -47,6 +47,7 @@ import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.TreeNodeTag +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.IntervalUtils import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} @@ -3407,7 +3408,7 @@ class Dataset[T] private[sql]( sparkSession, MapInPandas( func, - func.dataType.asInstanceOf[StructType].toAttributes, + toAttributes(func.dataType.asInstanceOf[StructType]), logicalPlan, isBarrier)) } @@ -3422,7 +3423,7 @@ class Dataset[T] private[sql]( sparkSession, PythonMapInArrow( func, - func.dataType.asInstanceOf[StructType].toAttributes, + toAttributes(func.dataType.asInstanceOf[StructType]), logicalPlan, isBarrier)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 29138b5bf58..11327cdf7d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.toPrettySQL import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.QueryExecution @@ -563,7 +564,7 @@ class RelationalGroupedDataset protected[sql]( val project = 
df.sparkSession.sessionState.executePlan( Project(groupingNamedExpressions ++ child.output, child)).analyzed val groupingAttributes = project.output.take(groupingNamedExpressions.length) - val output = expr.dataType.asInstanceOf[StructType].toAttributes + val output = toAttributes(expr.dataType.asInstanceOf[StructType]) val plan = FlatMapGroupsInPandas(groupingAttributes, expr, output, project) Dataset.ofRows(df.sparkSession, plan) @@ -608,7 +609,7 @@ class RelationalGroupedDataset protected[sql]( val right = r.df.sparkSession.sessionState.executePlan( Project(rightGroupingNamedExpressions ++ rightChild.output, rightChild)).analyzed - val output = expr.dataType.asInstanceOf[StructType].toAttributes + val output = toAttributes(expr.dataType.asInstanceOf[StructType]) val plan = FlatMapCoGroupsInPandas( leftGroupingNamedExpressions.length, rightGroupingNamedExpressions.length, expr, output, left, right) @@ -646,7 +647,7 @@ class RelationalGroupedDataset protected[sql]( case other => Alias(other, other.toString)() } val groupingAttrs = groupingNamedExpressions.map(_.toAttribute) - val outputAttrs = outputStructType.toAttributes + val outputAttrs = toAttributes(outputStructType) val plan = FlatMapGroupsInPandasWithState( func, groupingAttrs, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 2a1c2474bc6..87093c5c906 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -39,6 +39,7 @@ import org.apache.spark.sql.catalyst.analysis.{NameParameterizedQuery, PosParame import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range} +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.ExternalCommandRunner import org.apache.spark.sql.errors.QueryCompilationErrors @@ -299,7 +300,7 @@ class SparkSession private( */ def emptyDataset[T: Encoder]: Dataset[T] = { val encoder = implicitly[Encoder[T]] - new Dataset(self, LocalRelation(encoder.schema.toAttributes), encoder) + new Dataset(self, LocalRelation(encoder.schema), encoder) } /** @@ -319,7 +320,7 @@ class SparkSession private( */ def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = withActive { val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] - val attributeSeq = schema.toAttributes + val attributeSeq = toAttributes(schema) Dataset.ofRows(self, LocalRelation.fromProduct(attributeSeq, data)) } @@ -390,7 +391,7 @@ class SparkSession private( @DeveloperApi def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = withActive { val replaced = CharVarcharUtils.failIfHasCharVarchar(schema).asInstanceOf[StructType] - Dataset.ofRows(self, LocalRelation.fromExternalRows(replaced.toAttributes, rows.asScala.toSeq)) + Dataset.ofRows(self, LocalRelation.fromExternalRows(toAttributes(replaced), rows.asScala.toSeq)) } /** @@ -479,7 +480,7 @@ class SparkSession private( def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = { val enc = encoderFor[T] val toRow = enc.createSerializer() - val attributes = enc.schema.toAttributes + val attributes = toAttributes(enc.schema) val encoded = data.map(d => toRow(d).copy()) val plan = new LocalRelation(attributes, encoded) 
Dataset[T](self, plan) @@ -565,7 +566,7 @@ class SparkSession private( // TODO: use MutableProjection when rowRDD is another DataFrame and the applied // schema differs from the existing schema on any field data type. val logicalPlan = LogicalRDD( - schema.toAttributes, + toAttributes(schema), catalystRows, isStreaming = isStreaming)(self) Dataset.ofRows(self, logicalPlan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index c7392af360a..c81ef403980 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.logical.Aggregate +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution._ @@ -566,11 +567,11 @@ case class HashAggregateExec( ctx.currentVars = null ctx.INPUT_ROW = row val generateKeyRow = GenerateUnsafeProjection.createCode(ctx, - groupingKeySchema.toAttributes.zipWithIndex + toAttributes(groupingKeySchema).zipWithIndex .map { case (attr, i) => BoundReference(i, attr.dataType, attr.nullable) } ) val generateBufferRow = GenerateUnsafeProjection.createCode(ctx, - bufferSchema.toAttributes.zipWithIndex.map { case (attr, i) => + toAttributes(bufferSchema).zipWithIndex.map { case (attr, i) => BoundReference(groupingKeySchema.length + i, attr.dataType, attr.nullable) }) s""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index c1e225200f7..e517376bc5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.types._ @@ -373,7 +374,7 @@ case class ScalaUDAF( override val aggBufferSchema: StructType = udaf.bufferSchema - override val aggBufferAttributes: Seq[AttributeReference] = aggBufferSchema.toAttributes + override val aggBufferAttributes: Seq[AttributeReference] = toAttributes(aggBufferSchema) // Note: although this simply copies aggBufferAttributes, this common code can not be placed // in the superclass because that will lead to initialization ordering issues. 
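
Every hunk in this section applies the same one-line rewrite: the removed instance method on StructType becomes a call to the standalone helper. As a minimal sketch of what that helper plausibly looks like, assuming it simply ports the removed method (the object name below is invented for illustration; the real definition lives in org.apache.spark.sql.catalyst.types.DataTypeUtils):

    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.types.StructType

    // Sketch only: assumes the helper is a straight port of the removed method.
    // The object name is hypothetical; the real helper is
    // org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes.
    object DataTypeUtilsSketch {
      // Produces one unresolved AttributeReference per field, keeping the
      // field's name, data type, nullability and metadata.
      def toAttributes(schema: StructType): Seq[AttributeReference] =
        schema.map { field =>
          AttributeReference(field.name, field.dataType, field.nullable, field.metadata)()
        }
    }

Hosting this as a free function rather than a method on StructType is what lets the DataType hierarchy drop its dependency on catalyst expressions.
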
@@ -389,7 +390,7 @@ case class ScalaUDAF( } private lazy val inputProjection = { - val inputAttributes = childrenSchema.toAttributes + val inputAttributes = toAttributes(childrenSchema) log.debug( s"Creating MutableProj: $children, inputSchema: $inputAttributes.") MutableProjection.create(children, inputAttributes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 5d89065e04d..59d931bbe48 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.types._ import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} @@ -373,7 +374,7 @@ private[sql] object ArrowConverters extends Logging { schemaString: String, session: SparkSession): DataFrame = { val schema = DataType.fromJson(schemaString).asInstanceOf[StructType] - val attrs = schema.toAttributes + val attrs = toAttributes(schema) val batchesInDriver = arrowBatches.toArray val shouldUseRDD = session.sessionState.conf .arrowLocalRelationThreshold < batchesInDriver.map(_.length.toLong).sum diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala index 7ad9964a9ec..72502a7626b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala @@ -21,6 +21,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.IgnoreCachedData +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.types.{StringType, StructField, StructType} @@ -41,7 +42,7 @@ case class SetCommand(kv: Option[(String, Option[String])]) val schema = StructType(Array( StructField("key", StringType, nullable = false), StructField("value", StringType, nullable = false))) - schema.toAttributes + toAttributes(schema) } private val (_output, runFunc): (Seq[Attribute], SparkSession => Seq[Row]) = kv match { @@ -130,7 +131,7 @@ case class SetCommand(kv: Option[(String, Option[String])]) StructField("value", StringType, nullable = false), StructField("meaning", StringType, nullable = false), StructField("Since version", StringType, nullable = false))) - (schema.toAttributes, runFunc) + (toAttributes(schema), runFunc) // Queries the deprecated "mapred.reduce.tasks" property. 
case Some((SQLConf.Deprecated.MAPRED_REDUCE_TASKS, None)) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index bbe0d3c0c83..a8f7cdb2600 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, TableCatalog} @@ -939,8 +940,8 @@ object DDLUtils extends Logging { HiveTableRelation( table, // Hive table columns are always nullable. - table.dataSchema.asNullable.toAttributes, - table.partitionSchema.asNullable.toAttributes) + toAttributes(table.dataSchema.asNullable), + toAttributes(table.partitionSchema.asNullable)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala index eb88acd7b0b..b9a7151b4af 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, FunctionResource} import org.apache.spark.sql.catalyst.expressions.{Attribute, ExpressionInfo} +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types.{StringType, StructField, StructType} @@ -96,7 +97,7 @@ case class DescribeFunctionCommand( override val output: Seq[Attribute] = { val schema = StructType(Array(StructField("function_desc", StringType, nullable = false))) - schema.toAttributes + toAttributes(schema) } override def run(sparkSession: SparkSession): Seq[Row] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala index 2e71c829115..df16f8161d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormat.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.Filter @@ -137,7 +138,7 @@ trait FileFormat { sparkSession, dataSchema, partitionSchema, requiredSchema, filters, options, hadoopConf) new (PartitionedFile => Iterator[InternalRow]) with Serializable { - private 
val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes + private val fullSchema = toAttributes(requiredSchema) ++ toAttributes(partitionSchema) // Using lazy val to avoid serialization private lazy val appendPartitionColumns = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index 5a3ad0021ae..83064a8179e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference} import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{ExposesMetadataColumns, LeafNode, LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.{truncatedString, CharVarcharUtils} import org.apache.spark.sql.sources.BaseRelation @@ -91,13 +92,13 @@ object LogicalRelation { // The v1 source may return schema containing char/varchar type. We replace char/varchar // with "annotated" string type here as the query engine doesn't support char/varchar yet. val schema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(relation.schema) - LogicalRelation(relation, schema.toAttributes, None, isStreaming) + LogicalRelation(relation, toAttributes(schema), None, isStreaming) } def apply(relation: BaseRelation, table: CatalogTable): LogicalRelation = { // The v1 source may return schema containing char/varchar type. We replace char/varchar // with "annotated" string type here as the query engine doesn't support char/varchar yet. 
val schema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(relation.schema) - LogicalRelation(relation, schema.toAttributes, Some(table), false) + LogicalRelation(relation, toAttributes(schema), Some(table), false) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala index ef74036b23b..666ae9b5c6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala @@ -22,6 +22,7 @@ import scala.util.control.NonFatal import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.execution.command.LeafRunnableCommand import org.apache.spark.sql.sources.CreatableRelationProvider @@ -47,7 +48,7 @@ case class SaveIntoDataSourceCommand( sparkSession.sqlContext, mode, options, Dataset.ofRows(sparkSession, query)) try { - val logicalRelation = LogicalRelation(relation, relation.schema.toAttributes, None, false) + val logicalRelation = LogicalRelation(relation, toAttributes(relation.schema), None, false) sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, logicalRelation) } catch { case NonFatal(_) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala index e361e9eb8c0..28f05ca7f8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.ScanOperation import org.apache.spark.sql.catalyst.plans.logical.{Filter, LeafNode, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} @@ -184,8 +185,7 @@ object SchemaPruning extends Rule[LogicalPlan] { // We need to update the data type of the output attributes to use the pruned ones. 
// so that references to the original relation's output are not broken val nameAttributeMap = output.map(att => (att.name, att)).toMap - requiredSchema - .toAttributes + toAttributes(requiredSchema) .map { case att if nameAttributeMap.contains(att.name) => nameAttributeMap(att.name).withDataType(att.dataType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index a669adb29e7..b7e6f11f67d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ @@ -219,7 +220,7 @@ class OrcFileFormat val iter = new RecordReaderIterator[OrcStruct](orcRecordReader) Option(TaskContext.get()).foreach(_.addTaskCompletionListener[Unit](_ => iter.close())) - val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes + val fullSchema = toAttributes(requiredSchema) ++ toAttributes(partitionSchema) val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema) val deserializer = new OrcDeserializer(requiredSchema, requestedColIds) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index 738fe81ba9f..c131ad2cf31 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.parser.LegacyTypeStringParser +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.datasources._ @@ -323,7 +324,7 @@ class ParquetFileFormat try { readerWithRowIndexes.initialize(split, hadoopAttemptContext) - val fullSchema = requiredSchema.toAttributes ++ partitionSchema.toAttributes + val fullSchema = toAttributes(requiredSchema) ++ toAttributes(partitionSchema) val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema) if (partitionSchema.length == 0) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 0b4df18eb7c..3f235e10c81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions.{Expression, 
InputFileBlockLength, InputFileBlockStart, InputFileName, RowOrdering} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.TypeUtils._ import org.apache.spark.sql.connector.expressions.{FieldReference, RewritableTransform} import org.apache.spark.sql.errors.QueryCompilationErrors @@ -195,7 +196,7 @@ case class PreprocessTableCreation(catalog: SessionCatalog) extends Rule[Logical c.copy( tableDesc = existingTable, query = Some(TableOutputResolver.resolveOutputColumns( - tableDesc.qualifiedName, existingTable.schema.toAttributes, newQuery, + tableDesc.qualifiedName, toAttributes(existingTable.schema), newQuery, byName = true, conf))) // Here we normalize partition, bucket and sort column names, w.r.t. the case sensitivity diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala index a52ea901f02..71e86beefda 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Expression, ExpressionSet} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.connector.read.{Batch, InputPartition, Scan, Statistics, SupportsReportStatistics} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.PartitionedFileUtil @@ -85,11 +86,11 @@ trait FileScan extends Scan private lazy val (normalizedPartitionFilters, normalizedDataFilters) = { val partitionFilterAttributes = AttributeSet(partitionFilters).map(a => a.name -> a).toMap val normalizedPartitionFilters = ExpressionSet(partitionFilters.map( - QueryPlan.normalizeExpressions(_, fileIndex.partitionSchema.toAttributes + QueryPlan.normalizeExpressions(_, toAttributes(fileIndex.partitionSchema) .map(a => partitionFilterAttributes.getOrElse(a.name, a))))) val dataFiltersAttributes = AttributeSet(dataFilters).map(a => a.name -> a).toMap val normalizedDataFilters = ExpressionSet(dataFilters.map( - QueryPlan.normalizeExpressions(_, dataSchema.toAttributes + QueryPlan.normalizeExpressions(_, toAttributes(dataSchema) .map(a => dataFiltersAttributes.getOrElse(a.name, a))))) (normalizedPartitionFilters, normalizedDataFilters) } @@ -132,7 +133,7 @@ trait FileScan extends Scan protected def partitions: Seq[FilePartition] = { val selectedPartitions = fileIndex.listFiles(partitionFilters, dataFilters) val maxSplitBytes = FilePartition.maxSplitBytes(sparkSession, selectedPartitions) - val partitionAttributes = fileIndex.partitionSchema.toAttributes + val partitionAttributes = toAttributes(fileIndex.partitionSchema) val attributeMap = partitionAttributes.map(a => normalizeName(a.name) -> a).toMap val readPartitionAttributes = readPartitionSchema.map { readField => attributeMap.getOrElse(normalizeName(readField.name), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWrite.scala index 
b54f05bec12..e5f064fcf6e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWrite.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWrite.scala @@ -28,6 +28,7 @@ import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.connector.write.{BatchWrite, LogicalWriteInfo, Write} import org.apache.spark.sql.errors.QueryCompilationErrors @@ -119,7 +120,7 @@ trait FileWrite extends Write { // Note: prepareWrite has side effect. It sets "job". val outputWriterFactory = prepareWrite(sparkSession.sessionState.conf, job, caseInsensitiveOptions, schema) - val allColumns = schema.toAttributes + val allColumns = toAttributes(schema) val metrics: Map[String, SQLMetric] = BasicWriteJobStatsTracker.metrics val serializableHadoopConf = new SerializableConfiguration(hadoopConf) val statsTracker = new BasicWriteJobStatsTracker(serializableHadoopConf, metrics) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PartitionReaderWithPartitionValues.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PartitionReaderWithPartitionValues.scala index 7bca98e54ef..173348f907f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PartitionReaderWithPartitionValues.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PartitionReaderWithPartitionValues.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.JoinedRow import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.connector.read.PartitionReader import org.apache.spark.sql.types.StructType @@ -31,7 +32,7 @@ class PartitionReaderWithPartitionValues( readDataSchema: StructType, partitionSchema: StructType, partitionValues: InternalRow) extends PartitionReader[InternalRow] { - private val fullSchema = readDataSchema.toAttributes ++ partitionSchema.toAttributes + private val fullSchema = toAttributes(readDataSchema) ++ toAttributes(partitionSchema) private val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, fullSchema) // Note that we have to apply the converter even though `file.partitionValues` is empty. 
// This is because the converter is also responsible for converting safe `InternalRow`s into diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index fe19ac552f9..e78b1359021 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, SchemaPruning} +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.expressions.SortOrder import org.apache.spark.sql.connector.expressions.filter.Predicate @@ -212,7 +213,7 @@ object PushDownUtils { relation: DataSourceV2Relation): Seq[AttributeReference] = { val nameToAttr = Utils.toMap(relation.output.map(_.name), relation.output) val cleaned = CharVarcharUtils.replaceCharVarcharWithStringInSchema(schema) - cleaned.toAttributes.map { + toAttributes(cleaned).map { // we have to keep the attribute id during transformation a => a.withExprId(nameToAttr(a.name).exprId) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index e58fe7844ab..ef3982ff908 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.optimizer.CollapseProject import org.apache.spark.sql.catalyst.planning.{PhysicalOperation, ScanOperation} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LimitAndOffset, LocalLimit, LogicalPlan, Offset, OffsetAndLimit, Project, Sample, Sort} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.connector.expressions.{SortOrder => V2SortOrder} import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, Count, CountStar, Max, Min, Sum} import org.apache.spark.sql.connector.expressions.filter.Predicate @@ -330,7 +331,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { // DataSourceV2ScanRelation output columns. All the other columns are not // included in the output. 
val scan = holder.builder.build() - val realOutput = scan.readSchema().toAttributes + val realOutput = toAttributes(scan.readSchema()) assert(realOutput.length == holder.output.length, "The data source returns unexpected number of columns") val wrappedScan = getWrappedScan(scan, holder) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index 977dbe3f4ef..0870867bf21 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.{ExposesMetadataColumns, LeafNode, LogicalPlan, Statistics} +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.connector.read.streaming.SparkDataStream import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.LeafExecNode @@ -32,7 +33,7 @@ import org.apache.spark.sql.execution.datasources.{DataSource, FileFormat} object StreamingRelation { def apply(dataSource: DataSource): StreamingRelation = { StreamingRelation( - dataSource, dataSource.sourceInfo.name, dataSource.sourceInfo.schema.toAttributes) + dataSource, dataSource.sourceInfo.name, toAttributes(dataSource.sourceInfo.schema)) } } @@ -119,13 +120,13 @@ case class StreamingRelationExec( object StreamingExecutionRelation { def apply(source: Source, session: SparkSession): StreamingExecutionRelation = { - StreamingExecutionRelation(source, source.schema.toAttributes, None)(session) + StreamingExecutionRelation(source, toAttributes(source.schema), None)(session) } def apply( source: Source, session: SparkSession, catalogTable: CatalogTable): StreamingExecutionRelation = { - StreamingExecutionRelation(source, source.schema.toAttributes, Some(catalogTable))(session) + StreamingExecutionRelation(source, toAttributes(source.schema), Some(catalogTable))(session) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 1d377350253..34076f26fe8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability} import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader, PartitionReaderFactory, Scan, ScanBuilder} @@ -55,7 +56,7 @@ object MemoryStream { */ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends SparkDataStream { val encoder = encoderFor[A] - protected val attributes = encoder.schema.toAttributes + protected val attributes = 
toAttributes(encoder.schema) protected lazy val toRow: ExpressionEncoder.Serializer[A] = encoder.createSerializer() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleStreamingWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleStreamingWrite.scala index 5cb11b9280c..471fd7feedc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleStreamingWrite.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleStreamingWrite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.internal.Logging import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.connector.write.{PhysicalWriteInfo, WriterCommitMessage} import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite} import org.apache.spark.sql.types.StructType @@ -64,7 +65,7 @@ class ConsoleWrite(schema: StructType, options: CaseInsensitiveStringMap) println(printMessage) println("-------------------------------------------") // scalastyle:off println - Dataset.ofRows(spark, LocalRelation(schema.toAttributes, rows)) + Dataset.ofRows(spark, LocalRelation(toAttributes(schema), rows)) .show(numRowsToShow, isTruncated) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala index 6ebf2cf539e..bbbe28ec7ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.{ForeachWriter, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.connector.catalog.{SupportsWrite, Table, TableCapability} import org.apache.spark.sql.connector.write.{DataWriter, LogicalWriteInfo, PhysicalWriteInfo, SupportsTruncate, Write, WriteBuilder, WriterCommitMessage} import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite} @@ -81,7 +82,7 @@ class ForeachWrite[T]( val rowConverter: InternalRow => T = converter match { case Left(enc) => val boundEnc = enc.resolveAndBind( - inputSchema.toAttributes, + toAttributes(inputSchema), SparkSession.getActiveSession.get.sessionState.analyzer) boundEnc.createDeserializer() case Right(func) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala index d396e71177b..b68c08b3ea5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithStateExecHelper.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming.state import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder 
import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.execution.ObjectOperator import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP import org.apache.spark.sql.types._ @@ -110,7 +111,7 @@ object FlatMapGroupsWithStateExecHelper { private lazy val stateSerializerFunc = ObjectOperator.serializeObjectToRow(stateSerializerExprs) private lazy val stateDeserializerFunc = { - ObjectOperator.deserializeRowToObject(stateDeserializerExpr, stateSchema.toAttributes) + ObjectOperator.deserializeRowToObject(stateDeserializerExpr, toAttributes(stateSchema)) } private lazy val stateDataForGets = StateData() @@ -154,7 +155,7 @@ object FlatMapGroupsWithStateExecHelper { AttributeReference("timeoutTimestamp", dataType = IntegerType, nullable = false)() private val stateAttributes: Seq[Attribute] = { - val encSchemaAttribs = stateEncoder.schema.toAttributes + val encSchemaAttribs = toAttributes(stateEncoder.schema) if (shouldStoreTimestamp) encSchemaAttribs :+ timestampTimeoutAttribute else encSchemaAttribs } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala index 9e8356d3fdb..9e96da98eb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala @@ -25,6 +25,7 @@ import org.apache.spark.TaskContext import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, JoinedRow, Literal, SafeProjection, SpecificInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper._ @@ -380,7 +381,7 @@ class SymmetricHashJoinStateManager( private val keySchema = StructType( joinKeys.zipWithIndex.map { case (k, i) => StructField(s"field$i", k.dataType, k.nullable) }) - private val keyAttributes = keySchema.toAttributes + private val keyAttributes = toAttributes(keySchema) private val keyToNumValues = new KeyToNumValuesStore() private val keyWithIndexToValue = new KeyWithIndexToValueStore(stateFormatVersion) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index 76c89bfa4a3..796710a3567 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} import org.apache.spark.sql.catalyst.plans.logical.{CreateTable, LocalRelation, LogicalPlan, OptionList, RecoverPartitions, ShowFunctions, ShowNamespaces, ShowTables, UnresolvedTableSpec, View} +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogPlugin, CatalogV2Util, 
FunctionCatalog, Identifier, SupportsNamespaces, Table => V2Table, TableCatalog, V1Table} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.{CatalogHelper, MultipartIdentifierHelper, NamespaceHelper, TransformHelper} import org.apache.spark.sql.errors.QueryCompilationErrors @@ -936,7 +937,7 @@ private[sql] object CatalogImpl { val enc = ExpressionEncoder[T]() val toRow = enc.createSerializer() val encoded = data.map(d => toRow(d).copy()) - val plan = new LocalRelation(enc.schema.toAttributes, encoded) + val plan = new LocalRelation(DataTypeUtils.toAttributes(enc.schema), encoded) val queryExecution = sparkSession.sessionState.executePlan(plan) new Dataset[T](queryExecution, enc) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 13f7695947e..e3cd7379dad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -26,6 +26,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils} import org.apache.spark.sql.connector.catalog.{SupportsRead, TableProvider} import org.apache.spark.sql.connector.catalog.TableCapability._ @@ -185,7 +186,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo sparkSession, StreamingRelationV2( Some(provider), source, table, dsOptions, - table.columns.asSchema.toAttributes, None, None, v1Relation)) + toAttributes(table.columns.asSchema), None, None, v1Relation)) // fallback to v1 // TODO (SPARK-27483): we should move this fallback logic to an analyzer rule. 
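
For call sites outside catalyst, the rewrite is the same import-plus-call change seen throughout the hunks above. A self-contained usage sketch follows; the schema, object name, and printed output are made up for illustration:

    import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
    import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

    // Hypothetical example schema; any StructType is handled the same way.
    object ToAttributesMigrationExample {
      def main(args: Array[String]): Unit = {
        val schema = new StructType().add("id", IntegerType).add("name", StringType)
        // Before this patch: schema.toAttributes (catalyst-internal, now removed).
        // After this patch: the standalone helper takes the schema as an argument.
        val attrs = toAttributes(schema)
        attrs.foreach(a => println(s"${a.name}: ${a.dataType.simpleString}"))
      }
    }

Only the call shape changes; the attributes produced should be identical, so call sites need no adjustment beyond the import.
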
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index ac7915921d3..b7bafeca546 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.plans.logical.{CreateTable, OptionList, UnresolvedTableSpec} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.connector.catalog.{Identifier, SupportsWrite, Table, TableCatalog, TableProvider, V1Table, V2TableWithV1Fallback} import org.apache.spark.sql.connector.catalog.TableCapability._ @@ -347,7 +348,8 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { throw QueryCompilationErrors.queryNameNotSpecifiedForMemorySinkError() } val sink = new MemorySink() - val resultDf = Dataset.ofRows(df.sparkSession, new MemoryPlan(sink, df.schema.toAttributes)) + val resultDf = Dataset.ofRows(df.sparkSession, + MemoryPlan(sink, DataTypeUtils.toAttributes(df.schema))) val recoverFromCheckpoint = outputMode == OutputMode.Complete() val query = startQuery(sink, extraOptions, recoverFromCheckpoint = recoverFromCheckpoint, catalogTable = catalogTable) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala index 5f2e0b28aec..d7a8225a7d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, NamedRelation} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.connector.catalog.{Table, TableCapability} import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.execution.datasources.DataSource @@ -41,7 +42,7 @@ class TableCapabilityCheckSuite extends AnalysisTest with SharedSparkSession { "fake", table, CaseInsensitiveStringMap.empty(), - TableCapabilityCheckSuite.schema.toAttributes, + toAttributes(TableCapabilityCheckSuite.schema), None, None, v1Relation) @@ -207,7 +208,7 @@ private object TableCapabilityCheckSuite { private case object TestRelation extends LeafNode with NamedRelation { override def name: String = "source_relation" - override def output: Seq[AttributeReference] = TableCapabilityCheckSuite.schema.toAttributes + override def output: Seq[AttributeReference] = toAttributes(TableCapabilityCheckSuite.schema) } private case class CapabilityTable(_capabilities: TableCapability*) extends Table { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala index 0b51e6de5ed..6a69691bea8 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V2CommandsCaseSensitivitySuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.connector import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, CreateTablePartitioningValidationSuite, ResolvedTable, TestRelation2, TestTable2, UnresolvedFieldName, UnresolvedFieldPosition, UnresolvedIdentifier} import org.apache.spark.sql.catalyst.plans.logical.{AddColumns, AlterColumn, AlterTableCommand, CreateTableAsSelect, DropColumns, LogicalPlan, OptionList, QualifiedColType, RenameColumn, ReplaceColumns, ReplaceTableAsSelect, UnresolvedTableSpec} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.connector.catalog.Identifier import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition import org.apache.spark.sql.connector.expressions.Expressions @@ -41,7 +42,7 @@ class V2CommandsCaseSensitivitySuite catalog, Identifier.of(Array(), "table_name"), TestTable2, - schema.toAttributes) + toAttributes(schema)) override protected def extendedAnalysisRules: Seq[Rule[LogicalPlan]] = { Seq(PreprocessTableCreation(spark.sessionState.catalog)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala index 7af61bd2f3a..8b27a98e2b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType} class GroupedIteratorSuite extends SparkFunSuite { @@ -32,7 +33,7 @@ class GroupedIteratorSuite extends SparkFunSuite { val fromRow = encoder.createDeserializer() val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c")) val grouped = GroupedIterator(input.iterator.map(toRow), - Seq($"i".int.at(0)), schema.toAttributes) + Seq($"i".int.at(0)), toAttributes(schema)) val result = grouped.map { case (key, data) => @@ -59,7 +60,7 @@ class GroupedIteratorSuite extends SparkFunSuite { Row(3, 2L, "e")) val grouped = GroupedIterator(input.iterator.map(toRow), - Seq($"i".int.at(0), $"l".long.at(1)), schema.toAttributes) + Seq($"i".int.at(0), $"l".long.at(1)), toAttributes(schema)) val result = grouped.map { case (key, data) => @@ -80,7 +81,7 @@ class GroupedIteratorSuite extends SparkFunSuite { val toRow = encoder.createSerializer() val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c")) val grouped = GroupedIterator(input.iterator.map(toRow), - Seq($"i".int.at(0)), schema.toAttributes) + Seq($"i".int.at(0)), toAttributes(schema)) assert(grouped.length == 2) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index 3a0bd35cb70..b8f3ea3c6f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -23,6 +23,7 @@ import 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
index 3a0bd35cb70..b8f3ea3c6f3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
+import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
 import org.apache.spark.sql.execution.{CodegenSupport, LeafExecNode, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.adaptive.{DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
 import org.apache.spark.sql.functions._
@@ -61,7 +62,7 @@ abstract class DebuggingSuiteBase extends SharedSparkSession {

   case class DummyCodeGeneratorPlan(useInnerClass: Boolean)
     extends CodegenSupport with LeafExecNode {
-    override def output: Seq[Attribute] = StructType.fromDDL("d int").toAttributes
+    override def output: Seq[Attribute] = toAttributes(StructType.fromDDL("d int"))
     override def inputRDDs(): Seq[RDD[InternalRow]] = Seq(spark.sparkContext.emptyRDD[InternalRow])
     override protected def doExecute(): RDD[InternalRow] = sys.error("Not used")
     override protected def doProduce(ctx: CodegenContext): String = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala
index 014840d758c..54752ba1684 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala
@@ -23,6 +23,7 @@ import org.scalatest.BeforeAndAfter

 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.execution.streaming.sources._
 import org.apache.spark.sql.streaming.StreamTest
 import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
@@ -218,7 +219,7 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
     implicit val schema = new StructType().add(new StructField("value", IntegerType))
     val sink = new MemorySink
     val addBatch = addBatchFunc(sink, false) _
-    val plan = new MemoryPlan(sink, schema.toAttributes)
+    val plan = new MemoryPlan(sink, DataTypeUtils.toAttributes(schema))

     // Before adding data, check output
     checkAnswer(sink.allData, Seq.empty)
StructType().add("start", LongType).add("end", LongType)) .add("count", LongType) - private val rowAttributes = rowSchema.toAttributes + private val rowAttributes = toAttributes(rowSchema) private val keysWithSessionAttributes = rowAttributes.filter { attr => List("key1", "key2", "session").contains(attr.name) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowStateIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowStateIteratorSuite.scala index 81f1a3f785c..63d3c0bdf4a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowStateIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MergingSortWithSessionWindowStateIteratorSuite.scala @@ -24,6 +24,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.execution.streaming.state.{HDFSBackedStateStoreProvider, RocksDBStateStoreProvider, StateStore, StateStoreConf, StateStoreId, StateStoreProviderId, StreamingSessionWindowStateManager} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.StreamTest @@ -35,7 +36,7 @@ class MergingSortWithSessionWindowStateIteratorSuite extends StreamTest with Bef private val rowSchema = new StructType().add("key1", StringType).add("key2", IntegerType) .add("session", new StructType().add("start", LongType).add("end", LongType)) .add("value", LongType) - private val rowAttributes = rowSchema.toAttributes + private val rowAttributes = toAttributes(rowSchema) private val keysWithoutSessionAttributes = rowAttributes.filter { attr => List("key1", "key2").contains(attr.name) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionsIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionsIteratorSuite.scala index 34c4939cbc1..ae6dc536bf7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionsIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/UpdatingSessionsIteratorSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeRow} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes import org.apache.spark.sql.execution.aggregate.UpdatingSessionsIterator import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ @@ -34,7 +35,7 @@ class UpdatingSessionsIteratorSuite extends SharedSparkSession { private val rowSchema = new StructType().add("key1", StringType).add("key2", IntegerType) .add("session", new StructType().add("start", LongType).add("end", LongType)) .add("aggVal1", LongType).add("aggVal2", DoubleType) - private val rowAttributes = rowSchema.toAttributes + private val rowAttributes = toAttributes(rowSchema) private val noKeyRowAttributes = rowAttributes.filterNot { attr => Seq("key1", "key2").contains(attr.name) diff --git 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala
index daacdfd58c7..6685b140960 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingAggregationStateManagerSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming.state

 import org.apache.spark.sql.catalyst.expressions.{Attribute, SpecificInternalRow, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
 import org.apache.spark.sql.streaming.StreamTest
 import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
@@ -31,7 +32,7 @@ class StreamingAggregationStateManagerSuite extends StreamTest {
   val testOutputSchema: StructType = StructType(
     testKeys.map(createIntegerField) ++ testValues.map(createIntegerField))
-  val testOutputAttributes: Seq[Attribute] = testOutputSchema.toAttributes
+  val testOutputAttributes: Seq[Attribute] = toAttributes(testOutputSchema)
   val testKeyAttributes: Seq[Attribute] = testOutputAttributes.filter { p =>
     testKeys.contains(p.name)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManagerSuite.scala
index 096c3bb56f7..e3a3a382bf4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManagerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StreamingSessionWindowStateManagerSuite.scala
@@ -24,6 +24,7 @@ import org.scalatest.BeforeAndAfter

 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
 import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.StreamTest
@@ -34,7 +35,7 @@ class StreamingSessionWindowStateManagerSuite extends StreamTest with BeforeAndA
   private val rowSchema = new StructType().add("key1", StringType).add("key2", IntegerType)
     .add("session", new StructType().add("start", LongType).add("end", LongType))
     .add("value", LongType)
-  private val rowAttributes = rowSchema.toAttributes
+  private val rowAttributes = toAttributes(rowSchema)

   private val keysWithoutSessionAttributes = rowAttributes.filter { attr =>
     List("key1", "key2").contains(attr.name)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala
index 4e48dc119c3..b0abcbbe4d0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression, GenericInternalRow, LessThanOrEqual, Literal, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
 import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
+import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo
 import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.LeftSide
@@ -239,7 +240,7 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter
     val inputValueSchema = new StructType()
       .add(StructField("time", IntegerType, metadata = watermarkMetadata))
       .add(StructField("value", BooleanType))
-    val inputValueAttribs = inputValueSchema.toAttributes
+    val inputValueAttribs = toAttributes(inputValueSchema)
     val inputValueAttribWithWatermark = inputValueAttribs(0)
     val joinKeyExprs = Seq[Expression](Literal(false), inputValueAttribWithWatermark, Literal(10.0))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala
index fdc633f3556..0e3ba6b79eb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLAppStatusListenerSuite.scala
@@ -40,6 +40,7 @@ import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
 import org.apache.spark.sql.catalyst.util.quietly
 import org.apache.spark.sql.connector.{CSVDataWriter, CSVDataWriterFactory, RangeInputPartition, SimpleScanBuilder, SimpleWritableDataSource, TestLocalScanTable}
 import org.apache.spark.sql.connector.catalog.Table
@@ -846,7 +847,7 @@ abstract class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTes
     val oldCount = statusStore.executionsList().size

     val schema = new StructType().add("i", "int").add("j", "int")
-    val physicalPlan = BatchScanExec(schema.toAttributes, new CustomMetricScanBuilder(), Seq.empty,
+    val physicalPlan = BatchScanExec(toAttributes(schema), new CustomMetricScanBuilder(), Seq.empty,
       table = new TestLocalScanTable("fake"))
     val dummyQueryExecution = new QueryExecution(spark, LocalRelation()) {
       override lazy val sparkPlan = physicalPlan
@@ -885,7 +886,7 @@ abstract class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTes
     val oldCount = statusStore.executionsList().size

     val schema = new StructType().add("i", "int").add("j", "int")
-    val physicalPlan = BatchScanExec(schema.toAttributes, new CustomDriverMetricScanBuilder(),
+    val physicalPlan = BatchScanExec(toAttributes(schema), new CustomDriverMetricScanBuilder(),
       Seq.empty, table = new TestLocalScanTable("fake"))
     val dummyQueryExecution = new QueryExecution(spark, LocalRelation()) {
       override lazy val sparkPlan = physicalPlan
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index 776bcb3211f..746f289c393 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.execution.{FileSourceScanExec, SortExec, SparkPlan}
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper}
 import org.apache.spark.sql.execution.datasources.BucketingUtils
@@ -125,7 +126,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils with Adapti
     // Limit: bucket pruning only works when the bucket column has one and only one column
     assert(bucketColumnNames.length == 1)
     val bucketColumnIndex = bucketedDataFrame.schema.fieldIndex(bucketColumnNames.head)
-    val bucketColumn = bucketedDataFrame.schema.toAttributes(bucketColumnIndex)
+    val bucketColumn = DataTypeUtils.toAttribute(bucketedDataFrame.schema(bucketColumnIndex))

     // Filter could hide the bug in bucket pruning. Thus, skipping all the filters
     val plan = bucketedDataFrame.filter(filterCondition).queryExecution.executedPlan
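The BucketedReadSuite hunk above is the one call site in this patch that needs a single attribute rather than a whole schema, so it switches to the singular DataTypeUtils.toAttribute on the indexed StructField instead of converting the entire StructType and indexing the result. A small sketch of the two variants (the schema is illustrative; note that each conversion call produces attributes with fresh expression IDs, so the two results agree in name and type but are not the same attribute):

    import org.apache.spark.sql.catalyst.types.DataTypeUtils
    import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

    val schema = new StructType().add("a", IntegerType).add("b", StringType)

    // Convert everything, then index into the sequence (the old shape).
    val viaSeq = DataTypeUtils.toAttributes(schema)(1)

    // Convert only the field that is needed (the new shape).
    // StructType indexes like a Seq[StructField], so schema(1) is field "b".
    val viaField = DataTypeUtils.toAttribute(schema(1))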
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index 0ee44a098f7..c97979a57a5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -36,6 +36,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.plans.logical.{Range, RepartitionByExpression}
 import org.apache.spark.sql.catalyst.streaming.{InternalOutputModes, StreamingRelationV2}
+import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.{LocalLimitExec, SimpleMode, SparkPlan}
 import org.apache.spark.sql.execution.command.ExplainCommand
@@ -110,7 +111,7 @@ class StreamSuite extends StreamTest {
   test("StreamingExecutionRelation.computeStats") {
     val memoryStream = MemoryStream[Int]
     val executionRelation = StreamingExecutionRelation(
-      memoryStream, memoryStream.encoder.schema.toAttributes, None)(
+      memoryStream, toAttributes(memoryStream.encoder.schema), None)(
       memoryStream.sqlContext.sparkSession)
     assert(executionRelation.computeStats.sizeInBytes == spark.sessionState.conf.defaultSizeInBytes)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
index b889ac18974..4a6325eb060 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
@@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Literal, Rand, Randn, Shuffle, Uuid}
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.Complete
+import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
 import org.apache.spark.sql.connector.read.InputPartition
 import org.apache.spark.sql.connector.read.streaming.{Offset => OffsetV2}
 import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
@@ -1292,10 +1293,10 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi
       override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
         if (batchId == 0) {
           batchId += 1
-          Dataset.ofRows(spark, LocalRelation(schema.toAttributes, Nil, isStreaming = true))
+          Dataset.ofRows(spark, LocalRelation(toAttributes(schema), Nil, isStreaming = true))
         } else {
           Dataset.ofRows(spark,
-            LocalRelation(schema.toAttributes, InternalRow(10) :: Nil, isStreaming = true))
+            LocalRelation(toAttributes(schema), InternalRow(10) :: Nil, isStreaming = true))
         }
       }
       override def schema: StructType = MockSourceProvider.fakeSchema
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
index debe1ab734c..3fee7df19ab 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.{sources, SparkSession}
 import org.apache.spark.sql.catalyst.{expressions, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, GenericInternalRow, InterpretedProjection, JoinedRow, Literal, Predicate}
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.types.{DataType, StructType}
 import org.apache.spark.util.SerializableConfiguration
@@ -70,7 +71,7 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister {
     SimpleTextRelation.pushedFilters = filters.toSet

     val fieldTypes = dataSchema.map(_.dataType)
-    val inputAttributes = dataSchema.toAttributes
+    val inputAttributes = DataTypeUtils.toAttributes(dataSchema)
     val outputAttributes = requiredSchema.flatMap { field =>
       inputAttributes.find(_.name == field.name)
     }
@@ -106,7 +107,8 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister {
       }.filter(predicate.eval).map(projection)

       // Appends partition values
-      val fullOutput = requiredSchema.toAttributes ++ partitionSchema.toAttributes
+      val fullOutput = DataTypeUtils.toAttributes(requiredSchema) ++
+        DataTypeUtils.toAttributes(partitionSchema)
       val joinedRow = new JoinedRow()
       val appendPartitionColumns = GenerateUnsafeProjection.generate(fullOutput, fullOutput)

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org