This is an automated email from the ASF dual-hosted git repository.

yao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 5a5bf04aca8a [SPARK-51874][CORE][SQL] Add TypedConfigBuilder for Scala Enumeration
5a5bf04aca8a is described below

commit 5a5bf04aca8a0b600d91b9f71911c090a8fe14e3
Author: Kent Yao <y...@apache.org>
AuthorDate: Thu Apr 24 11:03:44 2025 +0800

[SPARK-51874][CORE][SQL] Add TypedConfigBuilder for Scala Enumeration

### What changes were proposed in this pull request?

This PR introduces TypedConfigBuilder for Scala Enumeration and leverages it for the existing configurations that take an Enumeration as their value. Before this PR, each such configuration had to convert back and forth between Enumeration and String, and to apply the upper-case transformation and .checkValues validation one entry at a time. After this PR, those steps are centralized in the builder.

### Why are the changes needed?

Better support for Enumeration-valued configurations.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #50674 from yaooqinn/enum.

Authored-by: Kent Yao <y...@apache.org>
Signed-off-by: Kent Yao <y...@apache.org>
---
 .../spark/internal/config/ConfigBuilder.scala      | 16 +++
 .../spark/internal/config/ConfigEntrySuite.scala   | 21 ++++
 .../sql/catalyst/analysis/CTESubstitution.scala    | 2 +-
 .../catalyst/analysis/resolver/ResolverGuard.scala | 2 +-
 .../CodeGeneratorWithInterpretedFallback.scala     | 2 +-
 .../sql/catalyst/expressions/ToStringBase.scala    | 2 +-
 .../sql/catalyst/util/ArrayBasedMapBuilder.scala   | 4 +-
 .../org/apache/spark/sql/internal/SQLConf.scala    | 125 ++++++++-------------
 .../apache/spark/sql/avro/AvroDeserializer.scala   | 4 +-
 .../org/apache/spark/sql/avro/AvroOptions.scala    | 6 +-
 .../apache/spark/sql/avro/AvroOutputWriter.scala   | 3 +-
 .../org/apache/spark/sql/avro/AvroSerializer.scala | 2 +-
 .../apache/spark/sql/execution/HiveResult.scala    | 3 +-
 .../sql/execution/WholeStageCodegenExec.scala      | 3 +-
 .../execution/datasources/DataSourceUtils.scala    | 8 +-
 .../datasources/parquet/ParquetOptions.scala       | 9 +-
 .../datasources/parquet/ParquetWriteSupport.scala  | 6 +-
 .../sql/execution/datasources/csv/CSVSuite.scala   | 2 +-
 .../sql/execution/datasources/json/JsonSuite.scala | 2 +-
 .../apache/spark/sql/internal/SQLConfSuite.scala   | 18 ++-
 .../hive/execution/HiveCompatibilitySuite.scala    | 2 +-
 .../spark/sql/hive/HiveSchemaInferenceSuite.scala  | 2 -
 22 files changed, 127 insertions(+), 117 deletions(-)

diff --git a/common/utils/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala b/common/utils/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala index d3e975d1782f..f68ced069505 100644 --- a/common/utils/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala +++ b/common/utils/src/main/scala/org/apache/spark/internal/config/ConfigBuilder.scala @@ -17,6 +17,7 @@ package org.apache.spark.internal.config +import java.util.Locale import java.util.concurrent.TimeUnit import java.util.regex.PatternSyntaxException @@ -46,6 +47,16 @@ private object ConfigHelpers { } } + def toEnum[E <: Enumeration](s: String, enumClass: E, key: String): enumClass.Value = { + try { + enumClass.withName(s.trim.toUpperCase(Locale.ROOT)) + } catch { + case _: NoSuchElementException => + throw new IllegalArgumentException( + s"$key should be one of ${enumClass.values.mkString(", ")}, but was $s") + } + } + def stringToSeq[T](str: String, converter: String => T): Seq[T] = { SparkStringUtils.stringToSeq(str).map(converter) } @@ -271,6
+282,11 @@ private[spark] case class ConfigBuilder(key: String) { new TypedConfigBuilder(this, v => v) } + def enumConf(e: Enumeration): TypedConfigBuilder[e.Value] = { + checkPrependConfig + new TypedConfigBuilder(this, toEnum(_, e, key)) + } + def timeConf(unit: TimeUnit): TypedConfigBuilder[Long] = { checkPrependConfig new TypedConfigBuilder(this, timeFromString(_, unit), timeToString(_, unit)) diff --git a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala index ae9973508405..5aa542a0b985 100644 --- a/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala +++ b/core/src/test/scala/org/apache/spark/internal/config/ConfigEntrySuite.scala @@ -387,4 +387,25 @@ class ConfigEntrySuite extends SparkFunSuite { ConfigBuilder(testKey("oc5")).onCreate(_ => onCreateCalled = true).fallbackConf(fallback) assert(onCreateCalled) } + + + test("SPARK-51874: Add Enum support to ConfigBuilder") { + object MyTestEnum extends Enumeration { + val X, Y, Z = Value + } + val conf = new SparkConf() + val enumConf = ConfigBuilder("spark.test.enum.key") + .enumConf(MyTestEnum) + .createWithDefault(MyTestEnum.X) + assert(conf.get(enumConf) === MyTestEnum.X) + conf.set(enumConf, MyTestEnum.Y) + assert(conf.get(enumConf) === MyTestEnum.Y) + conf.set(enumConf.key, "Z") + assert(conf.get(enumConf) === MyTestEnum.Z) + val e = intercept[IllegalArgumentException] { + conf.set(enumConf.key, "A") + conf.get(enumConf) + } + assert(e.getMessage === s"${enumConf.key} should be one of X, Y, Z, but was A") + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala index 5bbe85705ac1..19e58a6e370b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CTESubstitution.scala @@ -84,7 +84,7 @@ object CTESubstitution extends Rule[LogicalPlan] { val cteDefs = ArrayBuffer.empty[CTERelationDef] val (substituted, firstSubstituted) = - LegacyBehaviorPolicy.withName(conf.getConf(LEGACY_CTE_PRECEDENCE_POLICY)) match { + conf.getConf(LEGACY_CTE_PRECEDENCE_POLICY) match { case LegacyBehaviorPolicy.EXCEPTION => assertNoNameConflictsInCTE(plan) traverseAndSubstituteCTE(plan, forceInline, Seq.empty, cteDefs, None) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala index 238fed31741b..334ffcdf2f63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala @@ -426,7 +426,7 @@ class ResolverGuard(catalogManager: CatalogManager) extends SQLConfHelper { Some("hiveCaseSensitiveInferenceMode") } else if (conf.getConf(SQLConf.LEGACY_INLINE_CTE_IN_COMMANDS)) { Some("legacyInlineCTEInCommands") - } else if (LegacyBehaviorPolicy.withName(conf.getConf(SQLConf.LEGACY_CTE_PRECEDENCE_POLICY)) != + } else if (conf.getConf(SQLConf.LEGACY_CTE_PRECEDENCE_POLICY) != LegacyBehaviorPolicy.CORRECTED) { Some("legacyCTEPrecedencePolicy") } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala index 62a1afecfd7f..4a074bb3039b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala @@ -38,7 +38,7 @@ abstract class CodeGeneratorWithInterpretedFallback[IN, OUT] extends Logging { def createObject(in: IN): OUT = { // We are allowed to choose codegen-only or no-codegen modes if under tests. - val fallbackMode = CodegenObjectFactoryMode.withName(SQLConf.get.codegenFactoryMode) + val fallbackMode = SQLConf.get.codegenFactoryMode fallbackMode match { case CodegenObjectFactoryMode.CODEGEN_ONLY => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala index 6cfcde5f52da..2e649763a9ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala @@ -457,7 +457,7 @@ trait ToStringBase { self: UnaryExpression with TimeZoneAwareExpression => object ToStringBase { def getBinaryFormatter: BinaryFormatter = { val style = SQLConf.get.getConf(SQLConf.BINARY_OUTPUT_STYLE) - style.map(BinaryOutputStyle.withName) match { + style match { case Some(BinaryOutputStyle.UTF8) => (array: Array[Byte]) => UTF8String.fromBytes(array) case Some(BinaryOutputStyle.BASIC) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala index 136e8824569e..25d0f0325520 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala @@ -84,9 +84,9 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria keys.append(keyNormalized) values.append(value) } else { - if (mapKeyDedupPolicy == SQLConf.MapKeyDedupPolicy.EXCEPTION.toString) { + if (mapKeyDedupPolicy == SQLConf.MapKeyDedupPolicy.EXCEPTION) { throw QueryExecutionErrors.duplicateMapKeyFoundError(key) - } else if (mapKeyDedupPolicy == SQLConf.MapKeyDedupPolicy.LAST_WIN.toString) { + } else if (mapKeyDedupPolicy == SQLConf.MapKeyDedupPolicy.LAST_WIN) { // Overwrite the previous value, as the policy is last wins. values(index) = value } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e4766ee75f27..d5d216923587 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1209,10 +1209,8 @@ object SQLConf { "Unix epoch. 
TIMESTAMP_MILLIS is also standard, but with millisecond precision, which " + "means Spark has to truncate the microsecond portion of its timestamp value.") .version("2.3.0") - .stringConf - .transform(_.toUpperCase(Locale.ROOT)) - .checkValues(ParquetOutputTimestampType.values.map(_.toString)) - .createWithDefault(ParquetOutputTimestampType.INT96.toString) + .enumConf(ParquetOutputTimestampType) + .createWithDefault(ParquetOutputTimestampType.INT96) val PARQUET_COMPRESSION = buildConf("spark.sql.parquet.compression.codec") .doc("Sets the compression codec used when writing Parquet files. If either `compression` or " + @@ -1551,10 +1549,8 @@ object SQLConf { "attempt to write it to the table properties) and NEVER_INFER (the default mode-- fallback " + "to using the case-insensitive metastore schema instead of inferring).") .version("2.1.1") - .stringConf - .transform(_.toUpperCase(Locale.ROOT)) - .checkValues(HiveCaseSensitiveInferenceMode.values.map(_.toString)) - .createWithDefault(HiveCaseSensitiveInferenceMode.NEVER_INFER.toString) + .enumConf(HiveCaseSensitiveInferenceMode) + .createWithDefault(HiveCaseSensitiveInferenceMode.NEVER_INFER) val HIVE_TABLE_PROPERTY_LENGTH_THRESHOLD = buildConf("spark.sql.hive.tablePropertyLengthThreshold") @@ -1721,9 +1717,7 @@ object SQLConf { .doc("The output style used display binary data. Valid values are 'UTF-8', " + "'BASIC', 'BASE64', 'HEX', and 'HEX_DISCRETE'.") .version("4.0.0") - .stringConf - .transform(_.toUpperCase(Locale.ROOT)) - .checkValues(BinaryOutputStyle.values.map(_.toString)) + .enumConf(BinaryOutputStyle) .createOptional val PARTITION_COLUMN_TYPE_INFERENCE = @@ -2058,6 +2052,7 @@ object SQLConf { .createWithDefault(100) val CODEGEN_FACTORY_MODE = buildConf("spark.sql.codegen.factoryMode") + .internal() .doc("This config determines the fallback behavior of several codegen generators " + "during tests. `FALLBACK` means trying codegen first and then falling back to " + "interpreted if any compile error happens. Disabling fallback if `CODEGEN_ONLY`. " + @@ -2065,10 +2060,8 @@ object SQLConf { "this configuration is only for the internal usage, and NOT supposed to be set by " + "end users.") .version("2.4.0") - .internal() - .stringConf - .checkValues(CodegenObjectFactoryMode.values.map(_.toString)) - .createWithDefault(CodegenObjectFactoryMode.FALLBACK.toString) + .enumConf(CodegenObjectFactoryMode) + .createWithDefault(CodegenObjectFactoryMode.FALLBACK) val CODEGEN_FALLBACK = buildConf("spark.sql.codegen.fallback") .internal() @@ -3949,10 +3942,8 @@ object SQLConf { "dataframe.write.option(\"partitionOverwriteMode\", \"dynamic\").save(path)." ) .version("2.3.0") - .stringConf - .transform(_.toUpperCase(Locale.ROOT)) - .checkValues(PartitionOverwriteMode.values.map(_.toString)) - .createWithDefault(PartitionOverwriteMode.STATIC.toString) + .enumConf(PartitionOverwriteMode) + .createWithDefault(PartitionOverwriteMode.STATIC) object StoreAssignmentPolicy extends Enumeration { val ANSI, LEGACY, STRICT = Value @@ -3974,10 +3965,8 @@ object SQLConf { "not allowed." ) .version("3.0.0") - .stringConf - .transform(_.toUpperCase(Locale.ROOT)) - .checkValues(StoreAssignmentPolicy.values.map(_.toString)) - .createWithDefault(StoreAssignmentPolicy.ANSI.toString) + .enumConf(StoreAssignmentPolicy) + .createWithDefault(StoreAssignmentPolicy.ANSI) val ANSI_ENABLED = buildConf(SqlApiConfHelper.ANSI_ENABLED_KEY) .doc("When true, Spark SQL uses an ANSI compliant dialect instead of being Hive compliant. 
" + @@ -4630,10 +4619,8 @@ object SQLConf { "Before the 3.4.0 release, Spark only supports the TIMESTAMP WITH " + "LOCAL TIME ZONE type.") .version("3.4.0") - .stringConf - .transform(_.toUpperCase(Locale.ROOT)) - .checkValues(TimestampTypes.values.map(_.toString)) - .createWithDefault(TimestampTypes.TIMESTAMP_LTZ.toString) + .enumConf(TimestampTypes) + .createWithDefault(TimestampTypes.TIMESTAMP_LTZ) val DATETIME_JAVA8API_ENABLED = buildConf("spark.sql.datetime.java8API.enabled") .doc("If the configuration property is set to true, java.time.Instant and " + @@ -4708,10 +4695,8 @@ object SQLConf { "fails if duplicated map keys are detected. When LAST_WIN, the map key that is inserted " + "at last takes precedence.") .version("3.0.0") - .stringConf - .transform(_.toUpperCase(Locale.ROOT)) - .checkValues(MapKeyDedupPolicy.values.map(_.toString)) - .createWithDefault(MapKeyDedupPolicy.EXCEPTION.toString) + .enumConf(MapKeyDedupPolicy) + .createWithDefault(MapKeyDedupPolicy.EXCEPTION) val LEGACY_LOOSE_UPCAST = buildConf("spark.sql.legacy.doLooseUpcast") .internal() @@ -4727,10 +4712,8 @@ object SQLConf { "The default is CORRECTED, inner CTE definitions take precedence. This config " + "will be removed in future versions and CORRECTED will be the only behavior.") .version("3.0.0") - .stringConf - .transform(_.toUpperCase(Locale.ROOT)) - .checkValues(LegacyBehaviorPolicy.values.map(_.toString)) - .createWithDefault(LegacyBehaviorPolicy.CORRECTED.toString) + .enumConf(LegacyBehaviorPolicy) + .createWithDefault(LegacyBehaviorPolicy.CORRECTED) val CTE_RECURSION_LEVEL_LIMIT = buildConf("spark.sql.cteRecursionLevelLimit") .doc("Maximum level of recursion that is allowed while executing a recursive CTE definition." + @@ -4765,10 +4748,8 @@ object SQLConf { "When set to EXCEPTION, RuntimeException is thrown when we will get different " + "results. The default is CORRECTED.") .version("3.0.0") - .stringConf - .transform(_.toUpperCase(Locale.ROOT)) - .checkValues(LegacyBehaviorPolicy.values.map(_.toString)) - .createWithDefault(LegacyBehaviorPolicy.CORRECTED.toString) + .enumConf(LegacyBehaviorPolicy) + .createWithDefault(LegacyBehaviorPolicy.CORRECTED) val LEGACY_ARRAY_EXISTS_FOLLOWS_THREE_VALUED_LOGIC = buildConf("spark.sql.legacy.followThreeValuedLogicInArrayExists") @@ -5056,10 +5037,8 @@ object SQLConf { "When EXCEPTION, Spark will fail the writing if it sees ancient " + "timestamps that are ambiguous between the two calendars.") .version("3.1.0") - .stringConf - .transform(_.toUpperCase(Locale.ROOT)) - .checkValues(LegacyBehaviorPolicy.values.map(_.toString)) - .createWithDefault(LegacyBehaviorPolicy.CORRECTED.toString) + .enumConf(LegacyBehaviorPolicy) + .createWithDefault(LegacyBehaviorPolicy.CORRECTED) val PARQUET_REBASE_MODE_IN_WRITE = buildConf("spark.sql.parquet.datetimeRebaseModeInWrite") @@ -5073,10 +5052,8 @@ object SQLConf { "TIMESTAMP_MILLIS, TIMESTAMP_MICROS. The INT96 type has the separate config: " + s"${PARQUET_INT96_REBASE_MODE_IN_WRITE.key}.") .version("3.0.0") - .stringConf - .transform(_.toUpperCase(Locale.ROOT)) - .checkValues(LegacyBehaviorPolicy.values.map(_.toString)) - .createWithDefault(LegacyBehaviorPolicy.CORRECTED.toString) + .enumConf(LegacyBehaviorPolicy) + .createWithDefault(LegacyBehaviorPolicy.CORRECTED) val PARQUET_INT96_REBASE_MODE_IN_READ = buildConf("spark.sql.parquet.int96RebaseModeInRead") @@ -5088,10 +5065,8 @@ object SQLConf { "timestamps that are ambiguous between the two calendars. 
This config is only effective " + "if the writer info (like Spark, Hive) of the Parquet files is unknown.") .version("3.1.0") - .stringConf - .transform(_.toUpperCase(Locale.ROOT)) - .checkValues(LegacyBehaviorPolicy.values.map(_.toString)) - .createWithDefault(LegacyBehaviorPolicy.CORRECTED.toString) + .enumConf(LegacyBehaviorPolicy) + .createWithDefault(LegacyBehaviorPolicy.CORRECTED) val PARQUET_REBASE_MODE_IN_READ = buildConf("spark.sql.parquet.datetimeRebaseModeInRead") @@ -5107,10 +5082,8 @@ object SQLConf { s"${PARQUET_INT96_REBASE_MODE_IN_READ.key}.") .version("3.0.0") .withAlternative("spark.sql.legacy.parquet.datetimeRebaseModeInRead") - .stringConf - .transform(_.toUpperCase(Locale.ROOT)) - .checkValues(LegacyBehaviorPolicy.values.map(_.toString)) - .createWithDefault(LegacyBehaviorPolicy.CORRECTED.toString) + .enumConf(LegacyBehaviorPolicy) + .createWithDefault(LegacyBehaviorPolicy.CORRECTED) val AVRO_REBASE_MODE_IN_WRITE = buildConf("spark.sql.avro.datetimeRebaseModeInWrite") @@ -5121,10 +5094,8 @@ object SQLConf { "When EXCEPTION, Spark will fail the writing if it sees " + "ancient dates/timestamps that are ambiguous between the two calendars.") .version("3.0.0") - .stringConf - .transform(_.toUpperCase(Locale.ROOT)) - .checkValues(LegacyBehaviorPolicy.values.map(_.toString)) - .createWithDefault(LegacyBehaviorPolicy.CORRECTED.toString) + .enumConf(LegacyBehaviorPolicy) + .createWithDefault(LegacyBehaviorPolicy.CORRECTED) val AVRO_REBASE_MODE_IN_READ = buildConf("spark.sql.avro.datetimeRebaseModeInRead") @@ -5136,10 +5107,8 @@ object SQLConf { "ancient dates/timestamps that are ambiguous between the two calendars. This config is " + "only effective if the writer info (like Spark, Hive) of the Avro files is unknown.") .version("3.0.0") - .stringConf - .transform(_.toUpperCase(Locale.ROOT)) - .checkValues(LegacyBehaviorPolicy.values.map(_.toString)) - .createWithDefault(LegacyBehaviorPolicy.CORRECTED.toString) + .enumConf(LegacyBehaviorPolicy) + .createWithDefault(LegacyBehaviorPolicy.CORRECTED) val SCRIPT_TRANSFORMATION_EXIT_TIMEOUT = buildConf("spark.sql.scriptTransformation.exitTimeoutInSeconds") @@ -5480,9 +5449,8 @@ object SQLConf { "STANDARD includes an additional JSON field `message`. 
This configuration property " + "influences on error messages of Thrift Server and SQL CLI while running queries.") .version("3.4.0") - .stringConf.transform(_.toUpperCase(Locale.ROOT)) - .checkValues(ErrorMessageFormat.values.map(_.toString)) - .createWithDefault(ErrorMessageFormat.PRETTY.toString) + .enumConf(ErrorMessageFormat) + .createWithDefault(ErrorMessageFormat.PRETTY) val LATERAL_COLUMN_ALIAS_IMPLICIT_ENABLED = buildConf("spark.sql.lateralColumnAlias.enableImplicitResolution") @@ -6221,7 +6189,7 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def filesourcePartitionFileCacheSize: Long = getConf(HIVE_FILESOURCE_PARTITION_FILE_CACHE_SIZE) def caseSensitiveInferenceMode: HiveCaseSensitiveInferenceMode.Value = - HiveCaseSensitiveInferenceMode.withName(getConf(HIVE_CASE_SENSITIVE_INFERENCE)) + getConf(HIVE_CASE_SENSITIVE_INFERENCE) def gatherFastStats: Boolean = getConf(GATHER_FASTSTAT) @@ -6235,7 +6203,7 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def codegenFallback: Boolean = getConf(CODEGEN_FALLBACK) - def codegenFactoryMode: String = getConf(CODEGEN_FACTORY_MODE) + def codegenFactoryMode: CodegenObjectFactoryMode.Value = getConf(CODEGEN_FACTORY_MODE) def codegenComments: Boolean = getConf(StaticSQLConf.CODEGEN_COMMENTS) @@ -6311,9 +6279,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def legacyPostgresDatetimeMappingEnabled: Boolean = getConf(LEGACY_POSTGRES_DATETIME_MAPPING_ENABLED) - override def legacyTimeParserPolicy: LegacyBehaviorPolicy.Value = { - LegacyBehaviorPolicy.withName(getConf(SQLConf.LEGACY_TIME_PARSER_POLICY)) - } + override def legacyTimeParserPolicy: LegacyBehaviorPolicy.Value = + getConf(SQLConf.LEGACY_TIME_PARSER_POLICY) def broadcastHashJoinOutputPartitioningExpandLimit: Int = getConf(BROADCAST_HASH_JOIN_OUTPUT_PARTITIONING_EXPAND_LIMIT) @@ -6367,9 +6334,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def isParquetINT96TimestampConversion: Boolean = getConf(PARQUET_INT96_TIMESTAMP_CONVERSION) - def parquetOutputTimestampType: ParquetOutputTimestampType.Value = { - ParquetOutputTimestampType.withName(getConf(PARQUET_OUTPUT_TIMESTAMP_TYPE)) - } + def parquetOutputTimestampType: ParquetOutputTimestampType.Value = + getConf(PARQUET_OUTPUT_TIMESTAMP_TYPE) def writeLegacyParquetFormat: Boolean = getConf(PARQUET_WRITE_LEGACY_FORMAT) @@ -6647,10 +6613,10 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def validatePartitionColumns: Boolean = getConf(VALIDATE_PARTITION_COLUMNS) def partitionOverwriteMode: PartitionOverwriteMode.Value = - PartitionOverwriteMode.withName(getConf(PARTITION_OVERWRITE_MODE)) + getConf(PARTITION_OVERWRITE_MODE) def storeAssignmentPolicy: StoreAssignmentPolicy.Value = - StoreAssignmentPolicy.withName(getConf(STORE_ASSIGNMENT_POLICY)) + getConf(STORE_ASSIGNMENT_POLICY) override def ansiEnabled: Boolean = getConf(ANSI_ENABLED) @@ -6673,11 +6639,11 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def chunkBase64StringEnabled: Boolean = getConf(CHUNK_BASE64_STRING_ENABLED) def timestampType: AtomicType = getConf(TIMESTAMP_TYPE) match { - case "TIMESTAMP_LTZ" => + case TimestampTypes.TIMESTAMP_LTZ => // For historical reason, the TimestampType maps to TIMESTAMP WITH LOCAL TIME ZONE TimestampType - case "TIMESTAMP_NTZ" => + case TimestampTypes.TIMESTAMP_NTZ => TimestampNTZType } @@ -6827,8 +6793,7 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def histogramNumericPropagateInputType: Boolean = 
getConf(SQLConf.HISTOGRAM_NUMERIC_PROPAGATE_INPUT_TYPE) - def errorMessageFormat: ErrorMessageFormat.Value = - ErrorMessageFormat.withName(getConf(SQLConf.ERROR_MESSAGE_FORMAT)) + def errorMessageFormat: ErrorMessageFormat.Value = getConf(SQLConf.ERROR_MESSAGE_FORMAT) def defaultDatabase: String = getConf(StaticSQLConf.CATALOG_DEFAULT_DATABASE) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala index f66b5bd988c2..65fafb5a34c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala @@ -58,7 +58,7 @@ private[sql] class AvroDeserializer( def this( rootAvroType: Schema, rootCatalystType: DataType, - datetimeRebaseMode: String, + datetimeRebaseMode: LegacyBehaviorPolicy.Value, useStableIdForUnionType: Boolean, stableIdPrefixForUnionType: String, recursiveFieldMaxDepth: Int) = { @@ -66,7 +66,7 @@ private[sql] class AvroDeserializer( rootAvroType, rootCatalystType, positionalFieldMatch = false, - RebaseSpec(LegacyBehaviorPolicy.withName(datetimeRebaseMode)), + RebaseSpec(datetimeRebaseMode), new NoopFilters, useStableIdForUnionType, stableIdPrefixForUnionType, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala index d571b3ed6050..f63014d6ea9e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroOptions.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{DataSourceOptions, FileSourceOptions} import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, FailFastMode, ParseMode} import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf} /** * Options for Avro Reader and Writer stored in case insensitive manner. @@ -128,8 +128,8 @@ private[sql] class AvroOptions( /** * The rebasing mode for the DATE and TIMESTAMP_MICROS, TIMESTAMP_MILLIS values in reads. */ - val datetimeRebaseModeInRead: String = parameters - .get(DATETIME_REBASE_MODE) + val datetimeRebaseModeInRead: LegacyBehaviorPolicy.Value = parameters + .get(DATETIME_REBASE_MODE).map(LegacyBehaviorPolicy.withName) .getOrElse(SQLConf.get.getConf(SQLConf.AVRO_REBASE_MODE_IN_READ)) val useStableIdForUnionType: Boolean = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala index c4aaacf51545..767216b81992 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala @@ -44,8 +44,7 @@ private[avro] class AvroOutputWriter( avroSchema: Schema) extends OutputWriter { // Whether to rebase datetimes from Gregorian to Julian calendar in write - private val datetimeRebaseMode = LegacyBehaviorPolicy.withName( - SQLConf.get.getConf(SQLConf.AVRO_REBASE_MODE_IN_WRITE)) + private val datetimeRebaseMode = SQLConf.get.getConf(SQLConf.AVRO_REBASE_MODE_IN_WRITE) // The input rows will never be null. 
private lazy val serializer = new AvroSerializer( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala index 1d83a46a278f..402bab666948 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala @@ -53,7 +53,7 @@ private[sql] class AvroSerializer( def this(rootCatalystType: DataType, rootAvroType: Schema, nullable: Boolean) = { this(rootCatalystType, rootAvroType, nullable, positionalFieldMatch = false, - LegacyBehaviorPolicy.withName(SQLConf.get.getConf(SQLConf.AVRO_REBASE_MODE_IN_WRITE))) + SQLConf.get.getConf(SQLConf.AVRO_REBASE_MODE_IN_WRITE)) } def serialize(catalystData: Any): Any = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala index 360a0bb2d0ce..21cf70dab59f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.util.IntervalUtils.{durationToMicros, perio import org.apache.spark.sql.execution.command.{DescribeCommandBase, ExecutedCommandExec, ShowTablesCommand, ShowViewsCommand} import org.apache.spark.sql.execution.datasources.v2.{DescribeTableExec, ShowTablesExec} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.BinaryOutputStyle import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, VariantVal} import org.apache.spark.util.ArrayImplicits._ @@ -52,7 +53,7 @@ object HiveResult extends SQLConfHelper { def getBinaryFormatter: BinaryFormatter = { if (conf.getConf(SQLConf.BINARY_OUTPUT_STYLE).isEmpty) { // Keep the legacy behavior for compatibility. 
- conf.setConf(SQLConf.BINARY_OUTPUT_STYLE, Some("UTF-8")) + conf.setConf(SQLConf.BINARY_OUTPUT_STYLE, Some(BinaryOutputStyle.UTF8)) } ToStringBase.getBinaryFormatter(_).toString } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 920f61574770..1ee467ef3554 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -974,8 +974,7 @@ case class CollapseCodegenStages( } def apply(plan: SparkPlan): SparkPlan = { - if (conf.wholeStageEnabled && CodegenObjectFactoryMode.withName(conf.codegenFactoryMode) - != CodegenObjectFactoryMode.NO_CODEGEN) { + if (conf.wholeStageEnabled && conf.codegenFactoryMode != CodegenObjectFactoryMode.NO_CODEGEN) { insertWholeStageCodegen(plan) } else { plan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala index 875c5dfc5963..3e66b97f61a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala @@ -127,7 +127,7 @@ object DataSourceUtils extends PredicateHelper { private def getRebaseSpec( lookupFileMeta: String => String, - modeByConfig: String, + modeByConfig: LegacyBehaviorPolicy.Value, minVersion: String, metadataKey: String): RebaseSpec = { val policy = if (Utils.isTesting && @@ -145,7 +145,7 @@ object DataSourceUtils extends PredicateHelper { } else { LegacyBehaviorPolicy.CORRECTED } - }.getOrElse(LegacyBehaviorPolicy.withName(modeByConfig)) + }.getOrElse(modeByConfig) } policy match { case LegacyBehaviorPolicy.LEGACY => @@ -156,7 +156,7 @@ object DataSourceUtils extends PredicateHelper { def datetimeRebaseSpec( lookupFileMeta: String => String, - modeByConfig: String): RebaseSpec = { + modeByConfig: LegacyBehaviorPolicy.Value): RebaseSpec = { getRebaseSpec( lookupFileMeta, modeByConfig, @@ -166,7 +166,7 @@ object DataSourceUtils extends PredicateHelper { def int96RebaseSpec( lookupFileMeta: String => String, - modeByConfig: String): RebaseSpec = { + modeByConfig: LegacyBehaviorPolicy.Value): RebaseSpec = { getRebaseSpec( lookupFileMeta, modeByConfig, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala index e795d156d764..eaedd99d8628 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala @@ -24,7 +24,7 @@ import org.apache.parquet.hadoop.ParquetOutputFormat import org.apache.spark.sql.catalyst.{DataSourceOptions, FileSourceOptions} import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf} /** * Options for the Parquet data source. @@ -74,14 +74,15 @@ class ParquetOptions( /** * The rebasing mode for the DATE and TIMESTAMP_MICROS, TIMESTAMP_MILLIS values in reads. 
*/ - def datetimeRebaseModeInRead: String = parameters + def datetimeRebaseModeInRead: LegacyBehaviorPolicy.Value = parameters .get(DATETIME_REBASE_MODE) + .map(LegacyBehaviorPolicy.withName) .getOrElse(sqlConf.getConf(SQLConf.PARQUET_REBASE_MODE_IN_READ)) /** * The rebasing mode for INT96 timestamp values in reads. */ - def int96RebaseModeInRead: String = parameters - .get(INT96_REBASE_MODE) + def int96RebaseModeInRead: LegacyBehaviorPolicy.Value = parameters + .get(INT96_REBASE_MODE).map(LegacyBehaviorPolicy.withName) .getOrElse(sqlConf.getConf(SQLConf.PARQUET_INT96_REBASE_MODE_IN_READ)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala index c325871cc82b..4022f7ea3003 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetWriteSupport.scala @@ -83,8 +83,7 @@ class ParquetWriteSupport extends WriteSupport[InternalRow] with Logging { private val decimalBuffer = new Array[Byte](Decimal.minBytesForPrecision(DecimalType.MAX_PRECISION)) - private val datetimeRebaseMode = LegacyBehaviorPolicy.withName( - SQLConf.get.getConf(SQLConf.PARQUET_REBASE_MODE_IN_WRITE)) + private val datetimeRebaseMode = SQLConf.get.getConf(SQLConf.PARQUET_REBASE_MODE_IN_WRITE) private val dateRebaseFunc = DataSourceUtils.createDateRebaseFuncInWrite( datetimeRebaseMode, "Parquet") @@ -92,8 +91,7 @@ class ParquetWriteSupport extends WriteSupport[InternalRow] with Logging { private val timestampRebaseFunc = DataSourceUtils.createTimestampRebaseFuncInWrite( datetimeRebaseMode, "Parquet") - private val int96RebaseMode = LegacyBehaviorPolicy.withName( - SQLConf.get.getConf(SQLConf.PARQUET_INT96_REBASE_MODE_IN_WRITE)) + private val int96RebaseMode = SQLConf.get.getConf(SQLConf.PARQUET_INT96_REBASE_MODE_IN_WRITE) private val int96RebaseFunc = DataSourceUtils.createTimestampRebaseFuncInWrite( int96RebaseMode, "Parquet INT96") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 5c5efdbf6407..1684879612f5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -3753,5 +3753,5 @@ class CSVLegacyTimeParserSuite extends CSVSuite { override protected def sparkConf: SparkConf = super .sparkConf - .set(SQLConf.LEGACY_TIME_PARSER_POLICY, "legacy") + .set(SQLConf.LEGACY_TIME_PARSER_POLICY.key, "legacy") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 52ffdf6c6c0b..eb803d04f153 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -4149,7 +4149,7 @@ class JsonLegacyTimeParserSuite extends JsonSuite { override protected def sparkConf: SparkConf = super .sparkConf - .set(SQLConf.LEGACY_TIME_PARSER_POLICY, "legacy") + .set(SQLConf.LEGACY_TIME_PARSER_POLICY.key, "legacy") } class JsonUnsafeRowSuite extends JsonSuite { diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala index 0a50f07a1b2b..442dd09ce388 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.MIT import org.apache.spark.sql.classic.{SparkSession, SQLContext} import org.apache.spark.sql.execution.datasources.parquet.ParquetCompressionCodec.{GZIP, LZO} +import org.apache.spark.sql.internal.SQLConf.ParquetOutputTimestampType import org.apache.spark.sql.internal.StaticSQLConf._ import org.apache.spark.sql.test.{SharedSparkSession, TestSQLContext} import org.apache.spark.util.Utils @@ -351,10 +352,11 @@ class SQLConfSuite extends QueryTest with SharedSparkSession { assert(spark.sessionState.conf.parquetOutputTimestampType == SQLConf.ParquetOutputTimestampType.INT96) - sqlConf.setConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE, "timestamp_micros") + sqlConf.setConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE, + ParquetOutputTimestampType.TIMESTAMP_MICROS) assert(spark.sessionState.conf.parquetOutputTimestampType == - SQLConf.ParquetOutputTimestampType.TIMESTAMP_MICROS) - sqlConf.setConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE, "int96") + ParquetOutputTimestampType.TIMESTAMP_MICROS) + sqlConf.setConf(SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE, ParquetOutputTimestampType.INT96) assert(spark.sessionState.conf.parquetOutputTimestampType == SQLConf.ParquetOutputTimestampType.INT96) @@ -514,4 +516,14 @@ class SQLConfSuite extends QueryTest with SharedSparkSession { condition = "SQL_CONF_NOT_FOUND", parameters = Map("sqlConf" -> "\"some.conf\"")) } + + test("SPARK-51874: Add Enum support to ConfigBuilder") { + assert(spark.conf.get(SQLConf.LEGACY_TIME_PARSER_POLICY) === LegacyBehaviorPolicy.CORRECTED) + val e = intercept[IllegalArgumentException] { + spark.conf.set(SQLConf.LEGACY_TIME_PARSER_POLICY.key, "invalid") + } + assert(e.getMessage === + s"${SQLConf.LEGACY_TIME_PARSER_POLICY.key} should be one of " + + s"${LegacyBehaviorPolicy.values.mkString(", ")}, but was invalid") + } } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 539a0c12c05a..4f5f50433ea1 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -63,7 +63,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // Hive doesn't follow ANSI Standard. 
TestHive.setConf(SQLConf.ANSI_ENABLED, false) // Ensures that the table insertion behavior is consistent with Hive - TestHive.setConf(SQLConf.STORE_ASSIGNMENT_POLICY, StoreAssignmentPolicy.LEGACY.toString) + TestHive.setConf(SQLConf.STORE_ASSIGNMENT_POLICY, StoreAssignmentPolicy.LEGACY) // Fix session local timezone to America/Los_Angeles for those timezone sensitive tests // (timestamp_*) TestHive.setConf(SQLConf.SESSION_LOCAL_TIMEZONE, "America/Los_Angeles") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala index 1eb3f6f3c9cb..7bbd4fac3d0d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSchemaInferenceSuite.scala @@ -158,8 +158,6 @@ class HiveSchemaInferenceSuite SQLConf.HIVE_CASE_SENSITIVE_INFERENCE.key -> mode.toString)(f) } - private val inferenceKey = SQLConf.HIVE_CASE_SENSITIVE_INFERENCE.key - private def testFieldQuery(fields: Seq[String]): Unit = { if (!fields.isEmpty) { val query = s"SELECT * FROM ${TEST_TABLE_NAME} WHERE ${Random.shuffle(fields).head} >= 0" --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org
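The diff touches many files, but the core of the change is the new `enumConf` builder added in ConfigBuilder.scala. Below is a minimal usage sketch modeled on the SPARK-51874 test added to ConfigEntrySuite in this commit; the `DemoMode` enumeration and the `spark.demo.mode` key are illustrative only, and the snippet sits in the `org.apache.spark.internal.config` package because `ConfigBuilder` and the typed `SparkConf.get` are `private[spark]` APIs.

```scala
package org.apache.spark.internal.config

import org.apache.spark.SparkConf

// Hypothetical enumeration and config key, used only for illustration.
object DemoMode extends Enumeration {
  val FAST, SAFE, LEGACY = Value
}

object EnumConfDemo {
  def main(args: Array[String]): Unit = {
    // enumConf supplies the upper-casing, trimming and value validation that each
    // entry previously had to wire up with .stringConf / .transform / .checkValues.
    val modeConf = ConfigBuilder("spark.demo.mode")
      .enumConf(DemoMode)
      .createWithDefault(DemoMode.SAFE)

    val conf = new SparkConf()
    assert(conf.get(modeConf) == DemoMode.SAFE)   // typed default, no withName round-trip

    conf.set(modeConf.key, "fast")                // the string form is case-insensitive
    assert(conf.get(modeConf) == DemoMode.FAST)

    conf.set(modeConf.key, "bogus")               // rejected with a descriptive error:
    try conf.get(modeConf)                        // "spark.demo.mode should be one of
    catch {                                       //  FAST, SAFE, LEGACY, but was bogus"
      case e: IllegalArgumentException => println(e.getMessage)
    }
  }
}
```

Compared with the old pattern, the default is stored and returned as a typed Enumeration value, which is why the SQLConf call sites in the diff above can drop their `withName` conversions.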