This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 87449c3f1d65 [SPARK-47563][SQL] Add map normalization on creation
87449c3f1d65 is described below

commit 87449c3f1d65a430fec60981e364e552165fa075
Author: Stevo Mitric <stevo.mit...@databricks.com>
AuthorDate: Wed Mar 27 18:56:26 2024 +0800

    [SPARK-47563][SQL] Add map normalization on creation
    
    ### What changes were proposed in this pull request?
    Added normalization of map keys when they are put in `ArrayBasedMapBuilder`.
    
    ### Why are the changes needed?
    Map keys must be unique, so we normalize floating-point keys (folding 
-0.0 into 0.0) to prevent building a map with duplicate keys such as 
`Map(0.0, -0.0)`.
    This further unblocks the GROUP BY statement for Map types, as per [this 
discussion](https://github.com/apache/spark/pull/45549#discussion_r1537803505).
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    New UTs in `ArrayBasedMapBuilderSuite`
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #45721 from stevomitric/stevomitric/fix-map-dup.
    
    Authored-by: Stevo Mitric <stevo.mit...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/catalyst/util/ArrayBasedMapBuilder.scala   | 14 +++++++++++---
 .../catalyst/util/ArrayBasedMapBuilderSuite.scala  | 22 +++++++++++++++++++++-
 2 files changed, 32 insertions(+), 4 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala
index d358c92dd62c..d13c3c6026a2 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala
@@ -21,6 +21,7 @@ import scala.collection.mutable
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -52,18 +53,25 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: 
DataType) extends Seria
 
   private val mapKeyDedupPolicy = 
SQLConf.get.getConf(SQLConf.MAP_KEY_DEDUP_POLICY)
 
+  private lazy val keyNormalizer: Any => Any = keyType match {
+    case FloatType => NormalizeFloatingNumbers.FLOAT_NORMALIZER
+    case DoubleType => NormalizeFloatingNumbers.DOUBLE_NORMALIZER
+    case _ => identity
+  }
+
   def put(key: Any, value: Any): Unit = {
     if (key == null) {
       throw QueryExecutionErrors.nullAsMapKeyNotAllowedError()
     }
 
-    val index = keyToIndex.getOrDefault(key, -1)
+    val keyNormalized = keyNormalizer(key)
+    val index = keyToIndex.getOrDefault(keyNormalized, -1)
     if (index == -1) {
       if (size >= ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
         throw QueryExecutionErrors.exceedMapSizeLimitError(size)
       }
-      keyToIndex.put(key, values.length)
-      keys.append(key)
+      keyToIndex.put(keyNormalized, values.length)
+      keys.append(keyNormalized)
       values.append(value)
     } else {
       if (mapKeyDedupPolicy == SQLConf.MapKeyDedupPolicy.EXCEPTION.toString) {
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala
index 5811f4cd4c85..3c8c49ee7fec 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeRow}
 import org.apache.spark.sql.catalyst.plans.SQLHelper
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{ArrayType, BinaryType, IntegerType, 
StructType}
+import org.apache.spark.sql.types.{ArrayType, BinaryType, DoubleType, 
IntegerType, StructType}
 import org.apache.spark.unsafe.Platform
 
 class ArrayBasedMapBuilderSuite extends SparkFunSuite with SQLHelper {
@@ -60,6 +60,26 @@ class ArrayBasedMapBuilderSuite extends SparkFunSuite with 
SQLHelper {
     )
   }
 
+  test("apply key normalization when creating") {
+    val builderDouble = new ArrayBasedMapBuilder(DoubleType, IntegerType)
+    builderDouble.put(-0.0, 1)
+    checkError(
+      exception = intercept[SparkRuntimeException](builderDouble.put(0.0, 2)),
+      errorClass = "DUPLICATED_MAP_KEY",
+      parameters = Map(
+        "key" -> "0.0",
+        "mapKeyDedupPolicy" -> "\"spark.sql.mapKeyDedupPolicy\"")
+    )
+  }
+
+  test("successful map normalization on build") {
+    val builder = new ArrayBasedMapBuilder(DoubleType, IntegerType)
+    builder.put(-0.0, 1)
+    val map = builder.build()
+    assert(map.numElements() == 1)
+    assert(ArrayBasedMapData.toScalaMap(map) == Map(0.0 -> 1))
+  }
+
   test("remove duplicated keys with last wins policy") {
     withSQLConf(SQLConf.MAP_KEY_DEDUP_POLICY.key -> 
SQLConf.MapKeyDedupPolicy.LAST_WIN.toString) {
       val builder = new ArrayBasedMapBuilder(IntegerType, IntegerType)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to