This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 4663b84eed19 [SPARK-47681][FOLLOWUP] Fix schema_of_variant for float inputs
4663b84eed19 is described below

commit 4663b84eed1954b5df38e522e14eb2bee823f294
Author: Chenhao Li <[email protected]>
AuthorDate: Mon Jun 24 16:04:24 2024 +0800

    [SPARK-47681][FOLLOWUP] Fix schema_of_variant for float inputs
    
    ### What changes were proposed in this pull request?
    
    The current `schema_of_variant` depends on `JsonInferSchema.compatibleType`
    to find the common type of two types. This function doesn't handle the case
    of `float x decimal` or `decimal x float` correctly: instead of producing the
    expected common type (`double`), it considers the two types incompatible and
    produces `variant`. This change doesn't affect JSON schema inference, because
    `JsonInferSchema` never produces `float` types in the first place.
    
    ### Why are the changes needed?
    
    It is a bug fix: without it, `schema_of_variant` reports `variant` instead of `double` when merging `float` and `decimal` values.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
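    A new unit test in `VariantExpressionSuite` that runs `schema_of_variant` on every pairwise combination of a set of representative input types and checks the merged schema.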
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #47058 from chenhao-db/fix_schema_of_variant_float.
    
    Authored-by: Chenhao Li <[email protected]>
    Signed-off-by: Wenchen Fan <[email protected]>
---
 .../spark/sql/catalyst/json/JsonInferSchema.scala  |  5 +++
 .../variant/VariantExpressionSuite.scala           | 47 ++++++++++++++++++++++
 2 files changed, 52 insertions(+)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala
index 7ee522226e3e..d982e1f19da0 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala
@@ -372,6 +372,11 @@ object JsonInferSchema {
         case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) =>
           DoubleType
 
+        // This branch is only used by `SchemaOfVariant.mergeSchema` because 
`JsonInferSchema` never
+        // produces `FloatType`.
+        case (FloatType, _: DecimalType) | (_: DecimalType, FloatType) =>
+          DoubleType
+
         case (t1: DecimalType, t2: DecimalType) =>
           val scale = math.max(t1.scale, t2.scale)
           val range = math.max(t1.precision - t1.scale, t2.precision - 
t2.scale)
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
index 73abf8074e8c..a758fa84f6fc 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.variant
 
 import java.time.{LocalDateTime, ZoneId, ZoneOffset}
 
+import scala.collection.mutable
 import scala.reflect.runtime.universe.TypeTag
 
 import org.apache.spark.{SparkFunSuite, SparkRuntimeException}
@@ -860,4 +861,50 @@ class VariantExpressionSuite extends SparkFunSuite with 
ExpressionEvalHelper {
       StructType.fromDDL("c ARRAY<STRING>,b MAP<STRING, STRING>,a STRUCT<i: 
INT>"))
     check(struct, 
"""{"a":{"i":0},"b":{"a":"123","b":"true","c":"f"},"c":["123","true","f"]}""")
   }
+
+  test("schema_of_variant - schema merge") {
+    val nul = Literal(null, StringType)
+    val boolean = Literal.default(BooleanType)
+    val long = Literal.default(LongType)
+    val string = Literal.default(StringType)
+    val double = Literal.default(DoubleType)
+    val date = Literal.default(DateType)
+    val timestamp = Literal.default(TimestampType)
+    val timestampNtz = Literal.default(TimestampNTZType)
+    val float = Literal.default(FloatType)
+    val binary = Literal.default(BinaryType)
+    val decimal = Literal(Decimal("123.456"), DecimalType(6, 3))
+    val array1 = Literal(Array(0L))
+    val array2 = Literal(Array(0.0))
+    val struct1 = Literal.default(StructType.fromDDL("a string"))
+    val struct2 = Literal.default(StructType.fromDDL("a boolean, b bigint"))
+    val inputs = Seq(nul, boolean, long, string, double, date, timestamp, 
timestampNtz, float,
+      binary, decimal, array1, array2, struct1, struct2)
+
+    val results = mutable.HashMap.empty[(Literal, Literal), String]
+    for (i <- inputs) {
+      val inputType = if (i.value == null) "VOID" else i.dataType.sql
+      results.put((nul, i), inputType)
+      results.put((i, i), inputType)
+    }
+    results.put((long, double), "DOUBLE")
+    results.put((long, float), "FLOAT")
+    results.put((long, decimal), "DECIMAL(23,3)")
+    results.put((double, float), "DOUBLE")
+    results.put((double, decimal), "DOUBLE")
+    results.put((date, timestamp), "TIMESTAMP")
+    results.put((date, timestampNtz), "TIMESTAMP_NTZ")
+    results.put((timestamp, timestampNtz), "TIMESTAMP")
+    results.put((float, decimal), "DOUBLE")
+    results.put((array1, array2), "ARRAY<DOUBLE>")
+    results.put((struct1, struct2), "STRUCT<a: VARIANT, b: BIGINT>")
+
+    for (i1 <- inputs) {
+      for (i2 <- inputs) {
+        val expected = results.getOrElse((i1, i2), results.getOrElse((i2, i1), 
"VARIANT"))
+        val array = CreateArray(Seq(Cast(i1, VariantType), Cast(i2, 
VariantType)))
+        checkEvaluation(SchemaOfVariant(Cast(array, VariantType)).replacement, 
s"ARRAY<$expected>")
+      }
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to