This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 4663b84eed19 [SPARK-47681][FOLLOWUP] Fix schema_of_variant for float inputs
4663b84eed19 is described below
commit 4663b84eed1954b5df38e522e14eb2bee823f294
Author: Chenhao Li <[email protected]>
AuthorDate: Mon Jun 24 16:04:24 2024 +0800
[SPARK-47681][FOLLOWUP] Fix schema_of_variant for float inputs
### What changes were proposed in this pull request?
The current `schema_of_variant` depends on `JsonInferSchema.compatibleType`
to find the common type of two types. This function doesn't handle the case of
`float x decimal` or `decimal x float` correctly. It doesn't produce the
expected result, but instead considers the two types as incompatible and
produces `variant`. This change doesn't affect JSON schema inference because
`JsonInferSchema` never produces `float` in the first place.
### Why are the changes needed?
It is a bug fix and is required to process floats correctly.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
A new unit test that checks all type combinations.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #47058 from chenhao-db/fix_schema_of_variant_float.
Authored-by: Chenhao Li <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../spark/sql/catalyst/json/JsonInferSchema.scala | 5 +++
.../variant/VariantExpressionSuite.scala | 47 ++++++++++++++++++++++
2 files changed, 52 insertions(+)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala
index 7ee522226e3e..d982e1f19da0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala
@@ -372,6 +372,11 @@ object JsonInferSchema {
case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) =>
DoubleType
+ // This branch is only used by `SchemaOfVariant.mergeSchema` because `JsonInferSchema` never
+ // produces `FloatType`.
+ case (FloatType, _: DecimalType) | (_: DecimalType, FloatType) =>
+ DoubleType
+
case (t1: DecimalType, t2: DecimalType) =>
val scale = math.max(t1.scale, t2.scale)
val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
index 73abf8074e8c..a758fa84f6fc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.variant
import java.time.{LocalDateTime, ZoneId, ZoneOffset}
+import scala.collection.mutable
import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.{SparkFunSuite, SparkRuntimeException}
@@ -860,4 +861,50 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
StructType.fromDDL("c ARRAY<STRING>,b MAP<STRING, STRING>,a STRUCT<i: INT>"))
check(struct,
"""{"a":{"i":0},"b":{"a":"123","b":"true","c":"f"},"c":["123","true","f"]}""")
}
+
+ test("schema_of_variant - schema merge") {
+ val nul = Literal(null, StringType)
+ val boolean = Literal.default(BooleanType)
+ val long = Literal.default(LongType)
+ val string = Literal.default(StringType)
+ val double = Literal.default(DoubleType)
+ val date = Literal.default(DateType)
+ val timestamp = Literal.default(TimestampType)
+ val timestampNtz = Literal.default(TimestampNTZType)
+ val float = Literal.default(FloatType)
+ val binary = Literal.default(BinaryType)
+ val decimal = Literal(Decimal("123.456"), DecimalType(6, 3))
+ val array1 = Literal(Array(0L))
+ val array2 = Literal(Array(0.0))
+ val struct1 = Literal.default(StructType.fromDDL("a string"))
+ val struct2 = Literal.default(StructType.fromDDL("a boolean, b bigint"))
+ val inputs = Seq(nul, boolean, long, string, double, date, timestamp, timestampNtz, float,
+ binary, decimal, array1, array2, struct1, struct2)
+
+ val results = mutable.HashMap.empty[(Literal, Literal), String]
+ for (i <- inputs) {
+ val inputType = if (i.value == null) "VOID" else i.dataType.sql
+ results.put((nul, i), inputType)
+ results.put((i, i), inputType)
+ }
+ results.put((long, double), "DOUBLE")
+ results.put((long, float), "FLOAT")
+ results.put((long, decimal), "DECIMAL(23,3)")
+ results.put((double, float), "DOUBLE")
+ results.put((double, decimal), "DOUBLE")
+ results.put((date, timestamp), "TIMESTAMP")
+ results.put((date, timestampNtz), "TIMESTAMP_NTZ")
+ results.put((timestamp, timestampNtz), "TIMESTAMP")
+ results.put((float, decimal), "DOUBLE")
+ results.put((array1, array2), "ARRAY<DOUBLE>")
+ results.put((struct1, struct2), "STRUCT<a: VARIANT, b: BIGINT>")
+
+ for (i1 <- inputs) {
+ for (i2 <- inputs) {
+ val expected = results.getOrElse((i1, i2), results.getOrElse((i2, i1), "VARIANT"))
+ val array = CreateArray(Seq(Cast(i1, VariantType), Cast(i2, VariantType)))
+ checkEvaluation(SchemaOfVariant(Cast(array, VariantType)).replacement, s"ARRAY<$expected>")
+ }
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]