This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 484e7ac9acef [SPARK-48472][SQL] Enable reflect expressions with
collated strings
484e7ac9acef is described below
commit 484e7ac9acefc46ba6f1bc3019251441d9bb1507
Author: Mihailo Aleksic <[email protected]>
AuthorDate: Wed Jun 19 16:23:47 2024 +0800
[SPARK-48472][SQL] Enable reflect expressions with collated strings
### What changes were proposed in this pull request?
Changes made in this pull request enable collation of strings in "reflect"
expressions.
### Why are the changes needed?
These changes are a bug fix that enables users to use the feature mentioned above.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Using a unit test, which can be found in
sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #46991 from mihailoale-db/ReflectionCollationFix.
Authored-by: Mihailo Aleksic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../expressions/CallMethodViaReflection.scala | 23 ++++++----
.../spark/sql/CollationSQLExpressionsSuite.scala | 51 ++++++++++++++++++++++
2 files changed, 66 insertions(+), 8 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala
index c42b54222f17..13ea8c77c41b 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala
@@ -26,6 +26,8 @@ import
org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult
import
org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch,
TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.types.StringTypeAnyCollation
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.ArrayImplicits._
@@ -77,12 +79,12 @@ case class CallMethodViaReflection(
)
} else {
val unexpectedParameter = children.zipWithIndex.collectFirst {
- case (e, 0) if !(e.dataType == StringType && e.foldable) =>
+ case (e, 0) if !(e.dataType.isInstanceOf[StringType] && e.foldable) =>
DataTypeMismatch(
errorSubClass = "NON_FOLDABLE_INPUT",
messageParameters = Map(
"inputName" -> toSQLId("class"),
- "inputType" -> toSQLType(StringType),
+ "inputType" -> toSQLType(StringTypeAnyCollation),
"inputExpr" -> toSQLExpr(children.head)
)
)
@@ -90,12 +92,12 @@ case class CallMethodViaReflection(
DataTypeMismatch(
errorSubClass = "UNEXPECTED_NULL",
messageParameters = Map("exprName" -> toSQLId("class")))
- case (e, 1) if !(e.dataType == StringType && e.foldable) =>
+ case (e, 1) if !(e.dataType.isInstanceOf[StringType] && e.foldable) =>
DataTypeMismatch(
errorSubClass = "NON_FOLDABLE_INPUT",
messageParameters = Map(
"inputName" -> toSQLId("method"),
- "inputType" -> toSQLType(StringType),
+ "inputType" -> toSQLType(StringTypeAnyCollation),
"inputExpr" -> toSQLExpr(children(1))
)
)
@@ -103,14 +105,16 @@ case class CallMethodViaReflection(
DataTypeMismatch(
errorSubClass = "UNEXPECTED_NULL",
messageParameters = Map("exprName" -> toSQLId("method")))
- case (e, idx) if idx > 1 &&
!CallMethodViaReflection.typeMapping.contains(e.dataType) =>
+ case (e, idx) if idx > 1 &&
+ (!CallMethodViaReflection.typeMapping.contains(e.dataType)
+ && !e.dataType.isInstanceOf[StringType]) =>
DataTypeMismatch(
errorSubClass = "UNEXPECTED_INPUT_TYPE",
messageParameters = Map(
"paramIndex" -> ordinalNumber(idx),
"requiredType" -> toSQLType(
TypeCollection(BooleanType, ByteType, ShortType,
- IntegerType, LongType, FloatType, DoubleType, StringType)),
+ IntegerType, LongType, FloatType, DoubleType,
StringTypeAnyCollation)),
"inputSql" -> toSQLExpr(e),
"inputType" -> toSQLType(e.dataType))
)
@@ -134,7 +138,7 @@ case class CallMethodViaReflection(
}
override def nullable: Boolean = true
- override val dataType: DataType = StringType
+ override val dataType: DataType = SQLConf.get.defaultStringType
override protected def initializeInternal(partitionIndex: Int): Unit = {}
override protected def evalInternal(input: InternalRow): Any = {
@@ -230,7 +234,10 @@ object CallMethodViaReflection {
// Argument type must match. That is, either the method's argument
type matches one of the
// acceptable types defined in typeMapping, or it is a super type of
the acceptable types.
candidateTypes.zip(argTypes).forall { case (candidateType, argType) =>
- typeMapping(argType).exists(candidateType.isAssignableFrom)
+ if (!argType.isInstanceOf[StringType]) {
+ typeMapping(argType).exists(candidateType.isAssignableFrom)
+ }
+ else candidateType.isAssignableFrom(classOf[String])
}
}
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
index 0c54ccb7cfb1..0a7b513457a5 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
@@ -23,6 +23,7 @@ import java.text.SimpleDateFormat
import scala.collection.immutable.Seq
import org.apache.spark.{SparkConf, SparkException,
SparkIllegalArgumentException, SparkRuntimeException}
+import org.apache.spark.sql.catalyst.ExtendedAnalysisException
import org.apache.spark.sql.internal.{SqlApiConf, SQLConf}
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types._
@@ -2020,6 +2021,56 @@ class CollationSQLExpressionsSuite
})
}
+ test("Reflect expressions with collated strings") {
+ // be aware that output of java.util.UUID.fromString is always lowercase
+
+ case class ReflectExpressions(
+ left: String,
+ leftCollation: String,
+ right: String,
+ rightCollation: String,
+ result: Boolean
+ )
+
+ val testCases = Seq(
+ ReflectExpressions("a5cf6c42-0c85-418f-af6c-3e4e5b1328f2", "utf8_binary",
+ "a5cf6c42-0c85-418f-af6c-3e4e5b1328f2", "utf8_binary", true),
+ ReflectExpressions("a5cf6c42-0c85-418f-af6c-3e4e5b1328f2", "utf8_binary",
+ "A5Cf6c42-0c85-418f-af6c-3e4e5b1328f2", "utf8_binary", false),
+
+ ReflectExpressions("A5cf6C42-0C85-418f-af6c-3E4E5b1328f2", "utf8_binary",
+ "a5cf6c42-0c85-418f-af6c-3e4e5b1328f2", "utf8_lcase", true),
+ ReflectExpressions("A5cf6C42-0C85-418f-af6c-3E4E5b1328f2", "utf8_binary",
+ "A5Cf6c42-0c85-418f-af6c-3e4e5b1328f2", "utf8_lcase", true)
+ )
+ testCases.foreach(testCase => {
+ val query =
+ s"""
+ |SELECT REFLECT('java.util.UUID', 'fromString',
+ |collate('${testCase.left}', '${testCase.leftCollation}'))=
+ |collate('${testCase.right}', '${testCase.rightCollation}');
+ |""".stripMargin
+ val testQuery = sql(query)
+ checkAnswer(testQuery, Row(testCase.result))
+ })
+
+ val queryPass =
+ s"""
+ |SELECT REFLECT('java.lang.Integer', 'toHexString',2);
+ |""".stripMargin
+ val testQueryPass = sql(queryPass)
+ checkAnswer(testQueryPass, Row("2"))
+
+ val queryFail =
+ s"""
+ |SELECT REFLECT('java.lang.Integer', 'toHexString',"2");
+ |""".stripMargin
+ val typeException = intercept[ExtendedAnalysisException] {
+ sql(queryFail).collect()
+ }
+ assert(typeException.getErrorClass ===
"DATATYPE_MISMATCH.UNEXPECTED_STATIC_METHOD")
+ }
+
// TODO: Add more tests for other SQL expressions
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]