This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 38c6ef456ed6 [SPARK-50529][SQL] Change char/varchar behavior under the
`spark.sql.preserveCharVarcharTypeInfo` config
38c6ef456ed6 is described below
commit 38c6ef456ed67ed55e186625dcde017d85d6431f
Author: Jovan Markovic <[email protected]>
AuthorDate: Thu Dec 26 13:13:06 2024 +0800
[SPARK-50529][SQL] Change char/varchar behavior under the
`spark.sql.preserveCharVarcharTypeInfo` config
### What changes were proposed in this pull request?
This PR changes char/varchar behavior under the `PRESERVE_CHAR_VARCHAR_TYPE_INFO`
configuration flag (exposed as `spark.sql.preserveCharVarcharTypeInfo`). When the
flag is enabled, CHAR/VARCHAR types are no longer replaced with STRING in encoders,
literals, and type converters; instead their length checks and CHAR padding are
enforced.
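As a rough usage sketch (illustrative only and not part of the patch; the local
session setup and sample values are assumptions), the flag combined with the new
`Encoders.CHAR`/`Encoders.VARCHAR` helpers behaves roughly as follows:

```scala
import org.apache.spark.sql.{Encoders, SparkSession}

// Illustrative local session; the flag defaults to false.
val spark = SparkSession.builder().master("local[*]").getOrCreate()
spark.conf.set("spark.sql.preserveCharVarcharTypeInfo", "true")

// New public encoders added by this patch.
val charEnc = Encoders.CHAR(5)        // CHAR(5): write-side padding to length 5
val varcharEnc = Encoders.VARCHAR(5)  // VARCHAR(5): length check, no padding

// Short values are padded on the write side ("ab" -> "ab   ");
// values longer than 5 characters fail with EXCEED_LIMIT_LENGTH.
val ds = spark.createDataset(Seq("ab"))(charEnc)
```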
### Why are the changes needed?
This PR enables improving char/varchar type handling in a backwards-compatible
way: the new behavior is opt-in and the flag defaults to false.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Added tests in the following suites (a condensed behavior sketch follows the list):
- `RowEncoderSuite`
- `LiteralExpressionSuite`
- `CharVarcharTestSuite`
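As an illustration, the new `LiteralExpressionSuite` cases assert behavior along
these lines (a condensed excerpt of the added tests below; it relies on the suite's
`checkEvaluation`/`withSQLConf` helpers and is not standalone code):

```scala
withSQLConf(SQLConf.PRESERVE_CHAR_VARCHAR_TYPE_INFO.key -> "true") {
  // CHAR(5) literals are padded to the declared length...
  checkEvaluation(Literal.create("test", CharType(5)), "test ")
  // ...while VARCHAR(5) literals are only length-checked, not padded.
  checkEvaluation(Literal.create("test", VarcharType(5)), "test")
  // Values longer than the declared length fail with EXCEED_LIMIT_LENGTH.
  intercept[SparkRuntimeException](Literal.create("123456", CharType(5)))
}
```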
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #49128 from jovanm-db/char_varchar_conf.
Authored-by: Jovan Markovic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../main/scala/org/apache/spark/sql/Encoders.scala | 14 ++++++
.../sql/catalyst/encoders/AgnosticEncoder.scala | 2 +
.../spark/sql/catalyst/encoders/RowEncoder.scala | 12 ++---
.../sql/catalyst/util/SparkCharVarcharUtils.scala | 3 +-
.../org/apache/spark/sql/internal/SqlApiConf.scala | 2 +
.../sql/catalyst/CatalystTypeConverters.scala | 29 ++++++++++++
.../sql/catalyst/DeserializerBuildHelper.scala | 34 ++++++++++++++-
.../spark/sql/catalyst/SerializerBuildHelper.scala | 24 +++++++++-
.../sql/catalyst/analysis/CheckAnalysis.scala | 6 ++-
.../sql/catalyst/encoders/ExpressionEncoder.scala | 6 ++-
.../spark/sql/catalyst/expressions/literals.scala | 12 +++--
.../spark/sql/catalyst/util/CharVarcharUtils.scala | 12 ++++-
.../org/apache/spark/sql/internal/SQLConf.scala | 10 +++++
.../sql/catalyst/analysis/AnalysisSuite.scala | 18 ++++++++
.../sql/catalyst/encoders/RowEncoderSuite.scala | 37 ++++++++++++++++
.../expressions/LiteralExpressionSuite.scala | 51 +++++++++++++++++-----
.../catalyst/expressions/ToPrettyStringSuite.scala | 4 ++
.../execution/command/AnalyzeColumnCommand.scala | 1 +
.../apache/spark/sql/CharVarcharTestSuite.scala | 21 +++++++++
19 files changed, 265 insertions(+), 33 deletions(-)
diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Encoders.scala
b/sql/api/src/main/scala/org/apache/spark/sql/Encoders.scala
index 9976b34f7a01..4957d76af9a2 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/Encoders.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/Encoders.scala
@@ -81,6 +81,20 @@ object Encoders {
*/
def DOUBLE: Encoder[java.lang.Double] = BoxedDoubleEncoder
+ /**
+ * An encoder for nullable char type.
+ *
+ * @since 4.0.0
+ */
+ def CHAR(length: Int): Encoder[java.lang.String] = CharEncoder(length)
+
+ /**
+ * An encoder for nullable varchar type.
+ *
+ * @since 4.0.0
+ */
+ def VARCHAR(length: Int): Encoder[java.lang.String] = VarcharEncoder(length)
+
/**
* An encoder for nullable string type.
*
diff --git
a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
index 9ae7de97abf5..d998502ac1b2 100644
---
a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
+++
b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/AgnosticEncoder.scala
@@ -231,6 +231,8 @@ object AgnosticEncoders {
// Nullable leaf encoders
case object NullEncoder extends LeafEncoder[java.lang.Void](NullType)
case object StringEncoder extends LeafEncoder[String](StringType)
+ case class CharEncoder(length: Int) extends
LeafEncoder[String](CharType(length))
+ case class VarcharEncoder(length: Int) extends
LeafEncoder[String](VarcharType(length))
case object BinaryEncoder extends LeafEncoder[Array[Byte]](BinaryType)
case object ScalaBigIntEncoder extends
LeafEncoder[BigInt](DecimalType.BigIntDecimal)
case object JavaBigIntEncoder extends
LeafEncoder[JBigInt](DecimalType.BigIntDecimal)
diff --git
a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 3f384235ff32..718d99043abf 100644
---
a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++
b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable
import scala.reflect.classTag
import org.apache.spark.sql.{AnalysisException, Row}
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder,
BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder,
BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder,
DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder,
IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder,
MapEncoder, NullEncoder, RowEncoder => AgnosticRowEncoder, StringEncoder,
TimestampEncoder, UDTEncoder, VariantE [...]
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder,
BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder,
BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder,
CharEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, InstantEncoder,
IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder,
MapEncoder, NullEncoder, RowEncoder => AgnosticRowEncoder, StringEncoder,
TimestampEncoder, UDTEnco [...]
import org.apache.spark.sql.errors.{DataTypeErrorsBase, ExecutionErrors}
import org.apache.spark.sql.internal.SqlApiConf
import org.apache.spark.sql.types._
@@ -80,11 +80,11 @@ object RowEncoder extends DataTypeErrorsBase {
case DoubleType => BoxedDoubleEncoder
case dt: DecimalType => JavaDecimalEncoder(dt, lenientSerialization =
true)
case BinaryType => BinaryEncoder
- case CharType(_) | VarcharType(_) =>
- throw new AnalysisException(
- errorClass = "UNSUPPORTED_DATA_TYPE_FOR_ENCODER",
- messageParameters = Map("dataType" -> toSQLType(dataType)))
- case _: StringType => StringEncoder
+ case CharType(length) if SqlApiConf.get.preserveCharVarcharTypeInfo =>
+ CharEncoder(length)
+ case VarcharType(length) if SqlApiConf.get.preserveCharVarcharTypeInfo =>
+ VarcharEncoder(length)
+ case s: StringType if s.constraint == NoConstraint => StringEncoder
case TimestampType if SqlApiConf.get.datetimeJava8ApiEnabled =>
InstantEncoder(lenient)
case TimestampType => TimestampEncoder(lenient)
case TimestampNTZType => LocalDateTimeEncoder
diff --git
a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkCharVarcharUtils.scala
b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkCharVarcharUtils.scala
index 2a26c079e8d4..51b2c40f9bf2 100644
---
a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkCharVarcharUtils.scala
+++
b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/SparkCharVarcharUtils.scala
@@ -54,8 +54,7 @@ trait SparkCharVarcharUtils {
StructType(fields.map { field =>
field.copy(dataType = replaceCharVarcharWithString(field.dataType))
})
- case _: CharType => StringType
- case _: VarcharType => StringType
+ case CharType(_) | VarcharType(_) if
!SqlApiConf.get.preserveCharVarcharTypeInfo => StringType
case _ => dt
}
}
diff --git
a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala
b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala
index d5668cc72175..76cd436b39b5 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala
@@ -40,6 +40,7 @@ private[sql] trait SqlApiConf {
def timestampType: AtomicType
def allowNegativeScaleOfDecimalEnabled: Boolean
def charVarcharAsString: Boolean
+ def preserveCharVarcharTypeInfo: Boolean
def datetimeJava8ApiEnabled: Boolean
def sessionLocalTimeZone: String
def legacyTimeParserPolicy: LegacyBehaviorPolicy.Value
@@ -80,6 +81,7 @@ private[sql] object DefaultSqlApiConf extends SqlApiConf {
override def timestampType: AtomicType = TimestampType
override def allowNegativeScaleOfDecimalEnabled: Boolean = false
override def charVarcharAsString: Boolean = false
+ override def preserveCharVarcharTypeInfo: Boolean = false
override def datetimeJava8ApiEnabled: Boolean = false
override def sessionLocalTimeZone: String = TimeZone.getDefault.getID
override def legacyTimeParserPolicy: LegacyBehaviorPolicy.Value =
LegacyBehaviorPolicy.CORRECTED
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index 2b2a186f76d9..fab65251ed51 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -66,6 +66,8 @@ object CatalystTypeConverters {
case arrayType: ArrayType => ArrayConverter(arrayType.elementType)
case mapType: MapType => MapConverter(mapType.keyType, mapType.valueType)
case structType: StructType => StructConverter(structType)
+ case CharType(length) => new CharConverter(length)
+ case VarcharType(length) => new VarcharConverter(length)
case _: StringType => StringConverter
case DateType if SQLConf.get.datetimeJava8ApiEnabled =>
LocalDateConverter
case DateType => DateConverter
@@ -296,6 +298,33 @@ object CatalystTypeConverters {
toScala(row.getStruct(column, structType.size))
}
+ private class CharConverter(length: Int) extends CatalystTypeConverter[Any,
String, UTF8String] {
+ override def toCatalystImpl(scalaValue: Any): UTF8String =
+ CharVarcharCodegenUtils.charTypeWriteSideCheck(
+ StringConverter.toCatalystImpl(scalaValue), length)
+ override def toScala(catalystValue: UTF8String): String = if
(catalystValue == null) {
+ null
+ } else {
+ CharVarcharCodegenUtils.charTypeWriteSideCheck(catalystValue,
length).toString
+ }
+ override def toScalaImpl(row: InternalRow, column: Int): String =
+
CharVarcharCodegenUtils.charTypeWriteSideCheck(row.getUTF8String(column),
length).toString
+ }
+
+ private class VarcharConverter(length: Int)
+ extends CatalystTypeConverter[Any, String, UTF8String] {
+ override def toCatalystImpl(scalaValue: Any): UTF8String =
+ CharVarcharCodegenUtils.varcharTypeWriteSideCheck(
+ StringConverter.toCatalystImpl(scalaValue), length)
+ override def toScala(catalystValue: UTF8String): String = if
(catalystValue == null) {
+ null
+ } else {
+ CharVarcharCodegenUtils.varcharTypeWriteSideCheck(catalystValue,
length).toString
+ }
+ override def toScalaImpl(row: InternalRow, column: Int): String =
+
CharVarcharCodegenUtils.varcharTypeWriteSideCheck(row.getUTF8String(column),
length).toString
+ }
+
private object StringConverter extends CatalystTypeConverter[Any, String,
UTF8String] {
override def toCatalystImpl(scalaValue: Any): UTF8String = scalaValue
match {
case str: String => UTF8String.fromString(str)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
index 475243401537..55613b2b2013 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala
@@ -20,11 +20,11 @@ package org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.{expressions => exprs}
import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal,
UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder,
AgnosticEncoders, Codec, JavaSerializationCodec, KryoSerializationCodec}
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder,
BoxedLeafEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder,
IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder,
JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder,
OptionEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder,
PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder,
PrimitiveLongEncoder, PrimitiveShortEncoder, ProductEncoder, ScalaBigIn [...]
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder,
BoxedLeafEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder,
InstantEncoder, IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder,
JavaDecimalEncoder, JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder,
MapEncoder, OptionEncoder, PrimitiveBooleanEncoder, PrimitiveByteEncoder,
PrimitiveDoubleEncoder, PrimitiveFloatEncoder, PrimitiveIntEncoder,
PrimitiveLongEncoder, PrimitiveShortEncoder, ProductEncode [...]
import
org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor,
isNativeEncoder}
import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField,
IsNull, Literal, MapKeys, MapValues, UpCast}
import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull,
CreateExternalRow, DecodeUsingSerializer, InitializeJavaBean, Invoke,
NewInstance, StaticInvoke, UnresolvedCatalystToExternalMap,
UnresolvedMapObjects, WrapOption}
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils,
IntervalUtils}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData,
CharVarcharCodegenUtils, DateTimeUtils, IntervalUtils}
import org.apache.spark.sql.types._
object DeserializerBuildHelper {
@@ -80,6 +80,32 @@ object DeserializerBuildHelper {
returnNullable = false)
}
+ def createDeserializerForChar(
+ path: Expression,
+ returnNullable: Boolean,
+ length: Int): Expression = {
+ val expr = StaticInvoke(
+ classOf[CharVarcharCodegenUtils],
+ StringType,
+ "charTypeWriteSideCheck",
+ path :: Literal(length) :: Nil,
+ returnNullable = returnNullable)
+ createDeserializerForString(expr, returnNullable)
+ }
+
+ def createDeserializerForVarchar(
+ path: Expression,
+ returnNullable: Boolean,
+ length: Int): Expression = {
+ val expr = StaticInvoke(
+ classOf[CharVarcharCodegenUtils],
+ StringType,
+ "varcharTypeWriteSideCheck",
+ path :: Literal(length) :: Nil,
+ returnNullable = returnNullable)
+ createDeserializerForString(expr, returnNullable)
+ }
+
def createDeserializerForString(path: Expression, returnNullable: Boolean):
Expression = {
Invoke(path, "toString", ObjectType(classOf[java.lang.String]),
returnNullable = returnNullable)
@@ -258,6 +284,10 @@ object DeserializerBuildHelper {
"withName",
createDeserializerForString(path, returnNullable = false) :: Nil,
returnNullable = false)
+ case CharEncoder(length) =>
+ createDeserializerForChar(path, returnNullable = false, length)
+ case VarcharEncoder(length) =>
+ createDeserializerForVarchar(path, returnNullable = false, length)
case StringEncoder =>
createDeserializerForString(path, returnNullable = false)
case _: ScalaDecimalEncoder =>
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
index daebe15c298f..089d463ecacb 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SerializerBuildHelper.scala
@@ -22,11 +22,11 @@ import scala.language.existentials
import org.apache.spark.sql.catalyst.{expressions => exprs}
import
org.apache.spark.sql.catalyst.DeserializerBuildHelper.expressionWithNullSafety
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder,
AgnosticEncoders, Codec, JavaSerializationCodec, KryoSerializationCodec}
-import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder,
BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder,
BoxedIntEncoder, BoxedLeafEncoder, BoxedLongEncoder, BoxedShortEncoder,
DateEncoder, DayTimeIntervalEncoder, InstantEncoder, IterableEncoder,
JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder, JavaEnumEncoder,
LocalDateEncoder, LocalDateTimeEncoder, MapEncoder, OptionEncoder,
PrimitiveLeafEncoder, ProductEncoder, ScalaBigIntEncod [...]
+import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder,
BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder,
BoxedIntEncoder, BoxedLeafEncoder, BoxedLongEncoder, BoxedShortEncoder,
CharEncoder, DateEncoder, DayTimeIntervalEncoder, InstantEncoder,
IterableEncoder, JavaBeanEncoder, JavaBigIntEncoder, JavaDecimalEncoder,
JavaEnumEncoder, LocalDateEncoder, LocalDateTimeEncoder, MapEncoder,
OptionEncoder, PrimitiveLeafEncoder, ProductEncoder, Sca [...]
import
org.apache.spark.sql.catalyst.encoders.EncoderUtils.{externalDataTypeFor,
isNativeEncoder, lenientExternalDataTypeFor}
import org.apache.spark.sql.catalyst.expressions.{BoundReference,
CheckOverflow, CreateNamedStruct, Expression, IsNull, KnownNotNull, Literal,
UnsafeArrayData}
import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils,
GenericArrayData, IntervalUtils}
+import org.apache.spark.sql.catalyst.util.{ArrayData, CharVarcharCodegenUtils,
DateTimeUtils, GenericArrayData, IntervalUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -63,6 +63,24 @@ object SerializerBuildHelper {
Invoke(inputObject, "doubleValue", DoubleType)
}
+ def createSerializerForChar(inputObject: Expression, length: Int):
Expression = {
+ StaticInvoke(
+ classOf[CharVarcharCodegenUtils],
+ CharType(length),
+ "charTypeWriteSideCheck",
+ createSerializerForString(inputObject) :: Literal(length) :: Nil,
+ returnNullable = false)
+ }
+
+ def createSerializerForVarchar(inputObject: Expression, length: Int):
Expression = {
+ StaticInvoke(
+ classOf[CharVarcharCodegenUtils],
+ VarcharType(length),
+ "varcharTypeWriteSideCheck",
+ createSerializerForString(inputObject) :: Literal(length) :: Nil,
+ returnNullable = false)
+ }
+
def createSerializerForString(inputObject: Expression): Expression = {
StaticInvoke(
classOf[UTF8String],
@@ -298,6 +316,8 @@ object SerializerBuildHelper {
case BoxedDoubleEncoder => createSerializerForDouble(input)
case JavaEnumEncoder(_) => createSerializerForJavaEnum(input)
case ScalaEnumEncoder(_, _) => createSerializerForScalaEnum(input)
+ case CharEncoder(length) => createSerializerForChar(input, length)
+ case VarcharEncoder(length) => createSerializerForVarchar(input, length)
case StringEncoder => createSerializerForString(input)
case ScalaDecimalEncoder(dt) => createSerializerForBigDecimal(input, dt)
case JavaDecimalEncoder(dt, false) => createSerializerForBigDecimal(input,
dt)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index b0d6a2a46baa..6cd394fd79e9 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -283,9 +283,11 @@ trait CheckAnalysis extends PredicateHelper with
LookupCatalog with QueryErrorsB
plan.foreachUp {
case p if p.analyzed => // Skip already analyzed sub-plans
- case leaf: LeafNode if
leaf.output.map(_.dataType).exists(CharVarcharUtils.hasCharVarchar) =>
+ case leaf: LeafNode if !SQLConf.get.preserveCharVarcharTypeInfo &&
+ leaf.output.map(_.dataType).exists(CharVarcharUtils.hasCharVarchar) =>
throw SparkException.internalError(
- "Logical plan should not have output of char/varchar type: " + leaf)
+ s"Logical plan should not have output of char/varchar type when " +
+ s"${SQLConf.PRESERVE_CHAR_VARCHAR_TYPE_INFO.key} is false: " +
leaf)
case u: UnresolvedNamespace =>
u.schemaNotFound(u.multipartIdentifier)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index f2f86a90d517..5f0b42fec0fa 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -87,7 +87,8 @@ object ExpressionEncoder {
}
constructProjection(row).get(0, anyObjectType).asInstanceOf[T]
} catch {
- case e: SparkRuntimeException if e.getCondition ==
"NOT_NULL_ASSERT_VIOLATION" =>
+ case e: SparkRuntimeException if e.getCondition ==
"NOT_NULL_ASSERT_VIOLATION" ||
+ e.getCondition == "EXCEED_LIMIT_LENGTH" =>
throw e
case e: Exception =>
throw QueryExecutionErrors.expressionDecodingError(e, expressions)
@@ -115,7 +116,8 @@ object ExpressionEncoder {
inputRow(0) = t
extractProjection(inputRow)
} catch {
- case e: SparkRuntimeException if e.getCondition ==
"NOT_NULL_ASSERT_VIOLATION" =>
+ case e: SparkRuntimeException if e.getCondition ==
"NOT_NULL_ASSERT_VIOLATION" ||
+ e.getCondition == "EXCEED_LIMIT_LENGTH" =>
throw e
case e: Exception =>
throw QueryExecutionErrors.expressionEncodingError(e, expressions)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index b874cb53cb31..f3bed39bcb9f 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -166,6 +166,8 @@ object Literal {
case _: DayTimeIntervalType if v.isInstanceOf[Duration] =>
Literal(CatalystTypeConverters.createToCatalystConverter(dataType)(v),
dataType)
case _: ObjectType => Literal(v, dataType)
+ case CharType(_) | VarcharType(_) if
SQLConf.get.preserveCharVarcharTypeInfo =>
+ Literal(CatalystTypeConverters.createToCatalystConverter(dataType)(v),
dataType)
case _ => Literal(CatalystTypeConverters.convertToCatalyst(v), dataType)
}
}
@@ -196,9 +198,13 @@ object Literal {
case TimestampNTZType => create(0L, TimestampNTZType)
case it: DayTimeIntervalType => create(0L, it)
case it: YearMonthIntervalType => create(0, it)
- case CharType(_) | VarcharType(_) =>
- throw QueryExecutionErrors.noDefaultForDataTypeError(dataType)
- case st: StringType => Literal(UTF8String.fromString(""), st)
+ case CharType(length) =>
+
create(CharVarcharCodegenUtils.charTypeWriteSideCheck(UTF8String.fromString(""),
length),
+ dataType)
+ case VarcharType(length) =>
+
create(CharVarcharCodegenUtils.varcharTypeWriteSideCheck(UTF8String.fromString(""),
length),
+ dataType)
+ case st: StringType if st.constraint == NoConstraint =>
Literal(UTF8String.fromString(""), st)
case BinaryType => Literal("".getBytes(StandardCharsets.UTF_8))
case CalendarIntervalType => Literal(new CalendarInterval(0, 0, 0))
case arr: ArrayType => create(Array(), arr)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala
index 628fdcebd308..3db0f54f1a8f 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala
@@ -164,7 +164,11 @@ object CharVarcharUtils extends Logging with
SparkCharVarcharUtils {
case CharType(length) if charFuncName.isDefined =>
StaticInvoke(
classOf[CharVarcharCodegenUtils],
- StringType,
+ if (SQLConf.get.preserveCharVarcharTypeInfo) {
+ CharType(length)
+ } else {
+ StringType
+ },
charFuncName.get,
expr :: Literal(length) :: Nil,
returnNullable = false)
@@ -172,7 +176,11 @@ object CharVarcharUtils extends Logging with
SparkCharVarcharUtils {
case VarcharType(length) if varcharFuncName.isDefined =>
StaticInvoke(
classOf[CharVarcharCodegenUtils],
- StringType,
+ if (SQLConf.get.preserveCharVarcharTypeInfo) {
+ VarcharType(length)
+ } else {
+ StringType
+ },
varcharFuncName.get,
expr :: Literal(length) :: Nil,
returnNullable = false)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 08f77d58979f..5e630577638a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -4938,6 +4938,14 @@ object SQLConf {
.booleanConf
.createWithDefault(false)
+ val PRESERVE_CHAR_VARCHAR_TYPE_INFO =
buildConf("spark.sql.preserveCharVarcharTypeInfo")
+ .doc("When true, Spark does not replace CHAR/VARCHAR types the STRING
type, which is the " +
+ "default behavior of Spark 3.0 and earlier versions. This means the
length checks for " +
+ "CHAR/VARCHAR types is enforced and CHAR type is also properly padded.")
+ .version("4.0.0")
+ .booleanConf
+ .createWithDefault(false)
+
val READ_SIDE_CHAR_PADDING = buildConf("spark.sql.readSideCharPadding")
.doc("When true, Spark applies string padding when reading CHAR type
columns/fields, " +
"in addition to the write-side padding. This config is true by default
to better enforce " +
@@ -6343,6 +6351,8 @@ class SQLConf extends Serializable with Logging with
SqlApiConf {
def charVarcharAsString: Boolean =
getConf(SQLConf.LEGACY_CHAR_VARCHAR_AS_STRING)
+ def preserveCharVarcharTypeInfo: Boolean =
getConf(SQLConf.PRESERVE_CHAR_VARCHAR_TYPE_INFO)
+
def readSideCharPadding: Boolean = getConf(SQLConf.READ_SIDE_CHAR_PADDING)
def cliPrintHeader: Boolean = getConf(SQLConf.CLI_PRINT_HEADER)
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index ae27985a3ba6..2ffe6de974c7 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -81,6 +81,24 @@ class AnalysisSuite extends AnalysisTest with Matchers {
}
}
+ test(s"do not fail if a leaf node has char/varchar type output and " +
+ s"${SQLConf.PRESERVE_CHAR_VARCHAR_TYPE_INFO.key} is true") {
+ withSQLConf(SQLConf.PRESERVE_CHAR_VARCHAR_TYPE_INFO.key -> "true") {
+ val schema1 = new StructType().add("c", CharType(5))
+ val schema2 = new StructType().add("c", VarcharType(5))
+ val schema3 = new StructType().add("c", ArrayType(CharType(5)))
+ Seq(schema1, schema2, schema3).foreach { schema =>
+ val table = new InMemoryTable("t", schema, Array.empty,
Map.empty[String, String].asJava)
+ DataSourceV2Relation(
+ table,
+ DataTypeUtils.toAttributes(schema),
+ None,
+ None,
+ CaseInsensitiveStringMap.empty()).analyze
+ }
+ }
+ }
+
test("union project *") {
val plan = (1 to 120)
.map(_ => testRelation)
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 79c6d07d6d21..645b80ffaacb 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -488,4 +488,41 @@ class RowEncoderSuite extends CodegenInterpretedPlanTest {
val data = Row(mutable.ArraySeq.make(Array(Row("key", "value".getBytes))))
val row = encoder.createSerializer()(data)
}
+
+ test("do not allow serializing too long strings into char/varchar") {
+ Seq(CharType(5), VarcharType(5)).foreach { typ =>
+ withSQLConf(SQLConf.PRESERVE_CHAR_VARCHAR_TYPE_INFO.key -> "true") {
+ val schema = new StructType().add("c", typ)
+ val encoder = ExpressionEncoder(schema).resolveAndBind()
+ val value = "abcdef"
+ checkError(
+ exception = intercept[SparkRuntimeException]({
+ val row = toRow(encoder, Row(value))
+ }),
+ condition = "EXCEED_LIMIT_LENGTH",
+ parameters = Map("limit" -> "5")
+ )
+ }
+ }
+ }
+
+ test("do not allow deserializing too long strings into char/varchar") {
+ Seq(CharType(5), VarcharType(5)).foreach { typ =>
+ withSQLConf(SQLConf.PRESERVE_CHAR_VARCHAR_TYPE_INFO.key -> "true") {
+ val fromSchema = new StructType().add("c", StringType)
+ val fromEncoder = ExpressionEncoder(fromSchema).resolveAndBind()
+ val toSchema = new StructType().add("c", typ)
+ val toEncoder = ExpressionEncoder(toSchema).resolveAndBind()
+ val value = "abcdef"
+ val row = toRow(fromEncoder, Row(value))
+ checkError(
+ exception = intercept[SparkRuntimeException]({
+ val value = fromRow(toEncoder, row)
+ }),
+ condition = "EXCEED_LIMIT_LENGTH",
+ parameters = Map("limit" -> "5")
+ )
+ }
+ }
+ }
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
index b351d69d3a0b..5da5c6ac412c 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala
@@ -25,13 +25,12 @@ import java.util.TimeZone
import scala.collection.mutable
import scala.reflect.runtime.universe.TypeTag
-import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.{SparkFunSuite, SparkRuntimeException}
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection}
import org.apache.spark.sql.catalyst.encoders.ExamplePointUDT
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLType
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.types.DayTimeIntervalType._
@@ -91,16 +90,8 @@ class LiteralExpressionSuite extends SparkFunSuite with
ExpressionEvalHelper {
// ExamplePointUDT.sqlType is ArrayType(DoubleType, false).
checkEvaluation(Literal.default(new ExamplePointUDT), Array())
- // DateType without default value`
- List(CharType(1), VarcharType(1)).foreach(errType => {
- checkError(
- exception = intercept[SparkException] {
- Literal.default(errType)
- },
- condition = "INTERNAL_ERROR",
- parameters = Map("message" -> s"No default value for type:
${toSQLType(errType)}.")
- )
- })
+ checkEvaluation(Literal.default(CharType(5)), "     ")
+ checkEvaluation(Literal.default(VarcharType(5)), "")
}
test("boolean literals") {
@@ -160,6 +151,42 @@ class LiteralExpressionSuite extends SparkFunSuite with
ExpressionEvalHelper {
checkEvaluation(Literal.create("\u0000"), "\u0000")
}
+ test("char literals") {
+ withSQLConf(SQLConf.PRESERVE_CHAR_VARCHAR_TYPE_INFO.key -> "true") {
+ val typ = CharType(5)
+ checkEvaluation(Literal.create("", typ), " ")
+ checkEvaluation(Literal.create("test", typ), "test ")
+ checkEvaluation(Literal.create("test ", typ), "test ")
+ checkEvaluation(Literal.create("\u0000", typ), "\u0000 ")
+
+ checkError(
+ exception = intercept[SparkRuntimeException]({
+ Literal.create("123456", typ)
+ }),
+ condition = "EXCEED_LIMIT_LENGTH",
+ parameters = Map("limit" -> "5")
+ )
+ }
+ }
+
+ test("varchar literals") {
+ withSQLConf(SQLConf.PRESERVE_CHAR_VARCHAR_TYPE_INFO.key -> "true") {
+ val typ = VarcharType(5)
+ checkEvaluation(Literal.create("", typ), "")
+ checkEvaluation(Literal.create("test", typ), "test")
+ checkEvaluation(Literal.create("test ", typ), "test ")
+ checkEvaluation(Literal.create("\u0000", typ), "\u0000")
+
+ checkError(
+ exception = intercept[SparkRuntimeException]({
+ Literal.create("123456", typ)
+ }),
+ condition = "EXCEED_LIMIT_LENGTH",
+ parameters = Map("limit" -> "5")
+ )
+ }
+ }
+
test("sum two literals") {
checkEvaluation(Add(Literal(1), Literal(1)), 2)
checkEvaluation(Add(Literal.create(1), Literal.create(1)), 2)
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyStringSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyStringSuite.scala
index 783fba3bfc0d..2a5f76cab361 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyStringSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ToPrettyStringSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.UTC_OPT
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
@@ -89,6 +90,9 @@ class ToPrettyStringSuite extends SparkFunSuite with
ExpressionEvalHelper {
test("Char as pretty strings") {
checkEvaluation(ToPrettyString(Literal.create('a', CharType(5))), "a")
+ withSQLConf(SQLConf.PRESERVE_CHAR_VARCHAR_TYPE_INFO.key -> "true") {
+ checkEvaluation(ToPrettyString(Literal.create('a', CharType(5))), "a    ")
+ }
}
test("Byte as pretty strings") {
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
index 23555c98135f..1268b14a32fb 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala
@@ -140,6 +140,7 @@ case class AnalyzeColumnCommand(
case DoubleType | FloatType => true
case BooleanType => true
case _: DatetimeType => true
+ case CharType(_) | VarcharType(_) => false
case BinaryType | _: StringType => true
case _ => false
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
index d3b11274fe1c..a5cbeb552dcb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
@@ -86,6 +86,27 @@ trait CharVarcharTestSuite extends QueryTest with
SQLTestUtils {
}
}
+ test("preserve char/varchar type info") {
+ Seq(CharType(5), VarcharType(5)).foreach { typ =>
+ for {
+ char_varchar_as_string <- Seq(false, true)
+ preserve_char_varchar <- Seq(false, true)
+ } {
+ withSQLConf(SQLConf.LEGACY_CHAR_VARCHAR_AS_STRING.key ->
char_varchar_as_string.toString,
+ SQLConf.PRESERVE_CHAR_VARCHAR_TYPE_INFO.key ->
preserve_char_varchar.toString) {
+ withTable("t") {
+ val name = typ.typeName
+ sql(s"CREATE TABLE t(i STRING, c $name) USING $format")
+ val schema = spark.table("t").schema
+ assert(schema.fields(0).dataType == StringType)
+ val expectedType = if (preserve_char_varchar) typ else StringType
+ assert(schema.fields(1).dataType == expectedType)
+ }
+ }
+ }
+ }
+ }
+
test("char type values should be padded or trimmed: partitioned columns") {
// via dynamic partitioned columns
withTable("t") {