This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 87449c3f1d65 [SPARK-47563][SQL] Add map normalization on creation 87449c3f1d65 is described below commit 87449c3f1d65a430fec60981e364e552165fa075 Author: Stevo Mitric <stevo.mit...@databricks.com> AuthorDate: Wed Mar 27 18:56:26 2024 +0800 [SPARK-47563][SQL] Add map normalization on creation ### What changes were proposed in this pull request? Added normalization of map keys when they are put into `ArrayBasedMapBuilder`. ### Why are the changes needed? Because map keys must be unique, floating-point keys need to be normalized so that `0.0` and `-0.0` are treated as the same key; otherwise a map containing both could be built, e.g. `Map(0.0, -0.0)`. This also unblocks the GROUP BY statement for map types, as per [this discussion](https://github.com/apache/spark/pull/45549#discussion_r1537803505). ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? New unit tests in `ArrayBasedMapBuilderSuite` ### Was this patch authored or co-authored using generative AI tooling? No Closes #45721 from stevomitric/stevomitric/fix-map-dup. 
Authored-by: Stevo Mitric <stevo.mit...@databricks.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../sql/catalyst/util/ArrayBasedMapBuilder.scala | 14 +++++++++++--- .../catalyst/util/ArrayBasedMapBuilderSuite.scala | 22 +++++++++++++++++++++- 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala index d358c92dd62c..d13c3c6026a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala @@ -21,6 +21,7 @@ import scala.collection.mutable import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -52,18 +53,25 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria private val mapKeyDedupPolicy = SQLConf.get.getConf(SQLConf.MAP_KEY_DEDUP_POLICY) + private lazy val keyNormalizer: Any => Any = keyType match { + case FloatType => NormalizeFloatingNumbers.FLOAT_NORMALIZER + case DoubleType => NormalizeFloatingNumbers.DOUBLE_NORMALIZER + case _ => identity + } + def put(key: Any, value: Any): Unit = { if (key == null) { throw QueryExecutionErrors.nullAsMapKeyNotAllowedError() } - val index = keyToIndex.getOrDefault(key, -1) + val keyNormalized = keyNormalizer(key) + val index = keyToIndex.getOrDefault(keyNormalized, -1) if (index == -1) { if (size >= ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { throw QueryExecutionErrors.exceedMapSizeLimitError(size) } - keyToIndex.put(key, values.length) - keys.append(key) + keyToIndex.put(keyNormalized, values.length) + 
keys.append(keyNormalized) values.append(value) } else { if (mapKeyDedupPolicy == SQLConf.MapKeyDedupPolicy.EXCEPTION.toString) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala index 5811f4cd4c85..3c8c49ee7fec 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeRow} import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{ArrayType, BinaryType, IntegerType, StructType} +import org.apache.spark.sql.types.{ArrayType, BinaryType, DoubleType, IntegerType, StructType} import org.apache.spark.unsafe.Platform class ArrayBasedMapBuilderSuite extends SparkFunSuite with SQLHelper { @@ -60,6 +60,26 @@ class ArrayBasedMapBuilderSuite extends SparkFunSuite with SQLHelper { ) } + test("apply key normalization when creating") { + val builderDouble = new ArrayBasedMapBuilder(DoubleType, IntegerType) + builderDouble.put(-0.0, 1) + checkError( + exception = intercept[SparkRuntimeException](builderDouble.put(0.0, 2)), + errorClass = "DUPLICATED_MAP_KEY", + parameters = Map( + "key" -> "0.0", + "mapKeyDedupPolicy" -> "\"spark.sql.mapKeyDedupPolicy\"") + ) + } + + test("successful map normalization on build") { + val builder = new ArrayBasedMapBuilder(DoubleType, IntegerType) + builder.put(-0.0, 1) + val map = builder.build() + assert(map.numElements() == 1) + assert(ArrayBasedMapData.toScalaMap(map) == Map(0.0 -> 1)) + } + test("remove duplicated keys with last wins policy") { withSQLConf(SQLConf.MAP_KEY_DEDUP_POLICY.key -> 
SQLConf.MapKeyDedupPolicy.LAST_WIN.toString) { val builder = new ArrayBasedMapBuilder(IntegerType, IntegerType) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org