This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-4.1
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-4.1 by this push:
new 505d5f0d538d [SPARK-54339][SQL] Fix AttributeMap non-determinism
505d5f0d538d is described below
commit 505d5f0d538dd62a0aaaf09ec408e670d2dc7b4f
Author: Kelvin Jiang <[email protected]>
AuthorDate: Tue Nov 18 11:40:44 2025 +0800
[SPARK-54339][SQL] Fix AttributeMap non-determinism
### What changes were proposed in this pull request?
This PR fixes the `+`, `updated`, and `removed` methods of `AttributeMap`
to key entries by the attribute's `exprId` instead of hashing the `Attribute` as a whole.
### Why are the changes needed?
This change fixes non-determinism in `AttributeMap` when an entry is
added with `+` such that `attr1 != attr2` but
`attr1.exprId == attr2.exprId`.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Added a new test suite.
### Was this patch authored or co-authored using generative AI tooling?
Tests were generated by Claude Code on Sonnet 4.5.
Closes #53044 from kelvinjian-db/fix-attributemap.
Authored-by: Kelvin Jiang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit 78d1d52601e4c43bc4c23f543ffb411416f1f6cd)
Signed-off-by: Wenchen Fan <[email protected]>
---
.../sql/catalyst/expressions/AttributeMap.scala | 6 +-
.../catalyst/expressions/AttributeMapSuite.scala | 278 +++++++++++++++++++++
2 files changed, 281 insertions(+), 3 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
index b317cacc061b..9b6430c9ff0f 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
@@ -48,14 +48,14 @@ class AttributeMap[A](val baseMap: Map[ExprId, (Attribute,
A)])
override def contains(k: Attribute): Boolean = get(k).isDefined
override def + [B1 >: A](kv: (Attribute, B1)): AttributeMap[B1] =
- AttributeMap(baseMap.values.toMap + kv)
+ new AttributeMap(baseMap + (kv._1.exprId -> kv))
override def updated[B1 >: A](key: Attribute, value: B1): Map[Attribute, B1]
=
- baseMap.values.toMap + (key -> value)
+ this + (key -> value)
override def iterator: Iterator[(Attribute, A)] = baseMap.valuesIterator
- override def removed(key: Attribute): Map[Attribute, A] =
baseMap.values.toMap - key
+ override def removed(key: Attribute): Map[Attribute, A] = new
AttributeMap(baseMap - key.exprId)
def ++(other: AttributeMap[A]): AttributeMap[A] = new AttributeMap(baseMap
++ other.baseMap)
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeMapSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeMapSuite.scala
new file mode 100644
index 000000000000..fbb37d452437
--- /dev/null
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/AttributeMapSuite.scala
@@ -0,0 +1,278 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.types.{IntegerType, MetadataBuilder, StringType}
+
+class AttributeMapSuite extends SparkFunSuite {
+
+ val aUpper = AttributeReference("A", IntegerType)(exprId = ExprId(1))
+ val aLower = AttributeReference("a", IntegerType)(exprId = ExprId(1))
+ val fakeA = AttributeReference("a", IntegerType)(exprId = ExprId(3))
+
+ val bUpper = AttributeReference("B", IntegerType)(exprId = ExprId(2))
+ val bLower = AttributeReference("b", IntegerType)(exprId = ExprId(2))
+
+ val cAttr = AttributeReference("c", StringType)(exprId = ExprId(4))
+
+ test("basic map operations - get") {
+ val map = AttributeMap(Seq((aUpper, "value1"), (bUpper, "value2")))
+
+ // Should find by exprId, not by attribute equality
+ assert(map.get(aLower) === Some("value1"))
+ assert(map.get(aUpper) === Some("value1"))
+ assert(map.get(bLower) === Some("value2"))
+ assert(map.get(bUpper) === Some("value2"))
+
+ // Different exprId should not be found
+ assert(map.get(fakeA) === None)
+ }
+
+ test("basic map operations - contains") {
+ val map = AttributeMap(Seq((aUpper, "value1"), (bUpper, "value2")))
+
+ // Should find by exprId, not by attribute equality
+ assert(map.contains(aLower))
+ assert(map.contains(aUpper))
+ assert(map.contains(bUpper))
+ assert(!map.contains(fakeA))
+ }
+
+ test("basic map operations - getOrElse") {
+ val map = AttributeMap(Seq((aUpper, "value1")))
+
+ assert(map.getOrElse(aLower, "default") === "value1")
+ assert(map.getOrElse(fakeA, "default") === "default")
+ }
+
+ test("+ operator preserves ExprId-based hashing") {
+ val map1 = AttributeMap(Seq((aUpper, "value1")))
+ val map2 = map1 + (bUpper -> "value2")
+
+ // The resulting map should still be an AttributeMap
+ assert(map2.isInstanceOf[AttributeMap[_]])
+
+ // Should look up by exprId, not by attribute equality
+ assert(map2.get(aLower) === Some("value1"))
+ assert(map2.get(bLower) === Some("value2"))
+ }
+
+ test("+ operator with attribute having different metadata") {
+ val metadata1 = new MetadataBuilder().putString("key", "value1").build()
+ val metadata2 = new MetadataBuilder().putString("key", "value2").build()
+
+ // Two attributes with same exprId but different metadata
+ val attrWithMetadata1 = AttributeReference("col", IntegerType, metadata =
metadata1)(
+ exprId = ExprId(100))
+ val attrWithMetadata2 = AttributeReference("col", IntegerType, metadata =
metadata2)(
+ exprId = ExprId(100))
+
+ // These should have different hashCodes but same exprId
+ assert(attrWithMetadata1.hashCode() != attrWithMetadata2.hashCode(),
+ "Attributes with different metadata should have different hashCodes")
+ assert(attrWithMetadata1.exprId == attrWithMetadata2.exprId,
+ "Attributes should have the same exprId")
+
+ // Create a map with the first attribute
+ val map1 = AttributeMap(Seq((attrWithMetadata1, "original")))
+
+ // Add an entry using the + operator
+ val map2 = map1 + (cAttr -> "new")
+
+ // CRITICAL: The map should still find the original entry using an
attribute
+ // with the same exprId but different metadata
+ assert(map2.get(attrWithMetadata2) === Some("original"),
+ "AttributeMap should look up by exprId, not by attribute hashCode")
+
+ // And the new entry should also be present
+ assert(map2.get(cAttr) === Some("new"))
+ }
+
+ test("+ operator updates existing key") {
+ val map1 = AttributeMap(Seq((aUpper, "value1")))
+ val map2 = map1 + (aLower -> "updated")
+
+ // Since aLower has the same exprId as aUpper, it should update the value
+ assert(map2.get(aUpper) === Some("updated"))
+ assert(map2.get(aLower) === Some("updated"))
+ assert(map2.size === 1)
+ }
+
+ test("+ operator with type widening") {
+ val map1: AttributeMap[String] = AttributeMap(Seq((aUpper, "value1")))
+ val map2: AttributeMap[Any] = map1 + (bUpper -> 42)
+
+ assert(map2.get(aUpper) === Some("value1"))
+ assert(map2.get(bUpper) === Some(42))
+ }
+
+ test("++ operator preserves AttributeMap semantics") {
+ val map1 = AttributeMap(Seq((aUpper, "value1")))
+ val map2 = AttributeMap(Seq((bUpper, "value2")))
+ val combined = map1 ++ map2
+
+ assert(combined.isInstanceOf[AttributeMap[_]])
+ assert(combined.get(aLower) === Some("value1"))
+ assert(combined.get(bLower) === Some("value2"))
+ }
+
+ test("updated method") {
+ val map1 = AttributeMap(Seq((aUpper, "value1")))
+ val map2 = map1.updated(bUpper, "value2")
+
+ // Note: updated is declared to return Map[Attribute, B1]; the instance is still an AttributeMap
+ assert(map2.get(aUpper) === Some("value1"))
+ assert(map2.get(bUpper) === Some("value2"))
+ }
+
+ test("- operator (removal)") {
+ val map1 = AttributeMap(Seq((aUpper, "value1"), (bUpper, "value2")))
+ val map2 = map1 - aLower
+
+ // Note: - is declared to return Map[Attribute, A]; the instance is still an AttributeMap
+ assert(map2.get(aUpper) === None)
+ assert(map2.get(bUpper) === Some("value2"))
+ }
+
+ test("iterator") {
+ val map = AttributeMap(Seq((aUpper, "value1"), (bUpper, "value2")))
+ val entries = map.iterator.toSeq
+
+ assert(entries.size === 2)
+ assert(entries.contains((aUpper, "value1")))
+ assert(entries.contains((bUpper, "value2")))
+ }
+
+ test("size") {
+ val emptyMap = AttributeMap.empty[String]
+ assert(emptyMap.size === 0)
+
+ val map1 = AttributeMap(Seq((aUpper, "value1")))
+ assert(map1.size === 1)
+
+ val map2 = AttributeMap(Seq((aUpper, "value1"), (bUpper, "value2")))
+ assert(map2.size === 2)
+ }
+
+ test("empty map") {
+ val emptyMap = AttributeMap.empty[String]
+ assert(emptyMap.get(aUpper) === None)
+ assert(emptyMap.size === 0)
+ assert(!emptyMap.contains(aUpper))
+ }
+
+ test("duplicate keys in construction") {
+ // When constructing with duplicate exprIds, the last one should win
+ val map = AttributeMap(Seq((aUpper, "value1"), (aLower, "value2")))
+ assert(map.size === 1)
+ assert(map.get(aUpper) === Some("value2"))
+ }
+
+ test("map construction from Map") {
+ val regularMap = Map(aUpper -> "value1", bUpper -> "value2")
+ val attrMap = AttributeMap(regularMap)
+
+ assert(attrMap.get(aLower) === Some("value1"))
+ assert(attrMap.get(bLower) === Some("value2"))
+ }
+
+ test("chained + operations") {
+ val map = AttributeMap.empty[String] + (aUpper -> "value1") + (bUpper ->
"value2") +
+ (cAttr -> "value3")
+
+ assert(map.size === 3)
+ assert(map.get(aLower) === Some("value1"))
+ assert(map.get(bLower) === Some("value2"))
+ assert(map.get(cAttr) === Some("value3"))
+ }
+
+ test("+ should be deterministic with Attributes with diff hashcodes and same
exprId") {
+ // The HashMap needs to be of a certain size before the hashing starts to
collide, set up
+ // these AttributeMaps to start with size 5.
+ var map1 = AttributeMap(
+ Seq(
+ AttributeReference("a", IntegerType)(exprId = ExprId(1)) -> 1,
+ AttributeReference("b", IntegerType)(exprId = ExprId(2)) -> 2,
+ AttributeReference("c", IntegerType)(exprId = ExprId(3)) -> 3,
+ AttributeReference("d", IntegerType)(exprId = ExprId(4)) -> 4,
+ AttributeReference("e", IntegerType)(exprId = ExprId(5)) -> 5
+ )
+ )
+ var map2 = AttributeMap(
+ Seq(
+ AttributeReference("a", IntegerType)(exprId = ExprId(1)) -> 1,
+ AttributeReference("b", IntegerType)(exprId = ExprId(2)) -> 2,
+ AttributeReference("c", IntegerType)(exprId = ExprId(3)) -> 3,
+ AttributeReference("d", IntegerType)(exprId = ExprId(4)) -> 4,
+ AttributeReference("e", IntegerType)(exprId = ExprId(5)) -> 5
+ )
+ )
+ val qualifier1 = Seq("d")
+ val qualifier2 = Seq()
+ val exprId = ExprId(42)
+ val attr1 = AttributeReference("x", IntegerType)(exprId = exprId,
qualifier = qualifier1)
+ val attr2 = AttributeReference("x", IntegerType)(exprId = exprId,
qualifier = qualifier2)
+ assert(attr1.hashCode != attr2.hashCode)
+
+ map1 = map1 + (attr1 -> 100)
+ map1 = map1 + (attr2 -> 200)
+ assert(map1.get(attr2) === Some(200))
+
+ map2 = map2 + (attr2 -> 200)
+ map2 = map2 + (attr1 -> 100)
+ assert(map2.get(attr2) === Some(100))
+ }
+
+ test("updated should be deterministic with Attributes with diff hashcodes
and same exprId") {
+ // The HashMap needs to be of a certain size before the hashing starts to
collide, set up
+ // these AttributeMaps to start with size 5.
+ var map1: Map[Attribute, Int] = AttributeMap(
+ Seq(
+ AttributeReference("a", IntegerType)(exprId = ExprId(1)) -> 1,
+ AttributeReference("b", IntegerType)(exprId = ExprId(2)) -> 2,
+ AttributeReference("c", IntegerType)(exprId = ExprId(3)) -> 3,
+ AttributeReference("d", IntegerType)(exprId = ExprId(4)) -> 4,
+ AttributeReference("e", IntegerType)(exprId = ExprId(5)) -> 5
+ )
+ )
+ var map2: Map[Attribute, Int] = AttributeMap(
+ Seq(
+ AttributeReference("a", IntegerType)(exprId = ExprId(1)) -> 1,
+ AttributeReference("b", IntegerType)(exprId = ExprId(2)) -> 2,
+ AttributeReference("c", IntegerType)(exprId = ExprId(3)) -> 3,
+ AttributeReference("d", IntegerType)(exprId = ExprId(4)) -> 4,
+ AttributeReference("e", IntegerType)(exprId = ExprId(5)) -> 5
+ )
+ )
+ val qualifier1 = Seq("d")
+ val qualifier2 = Seq()
+ val exprId = ExprId(42)
+ val attr1 = AttributeReference("x", IntegerType)(exprId = exprId,
qualifier = qualifier1)
+ val attr2 = AttributeReference("x", IntegerType)(exprId = exprId,
qualifier = qualifier2)
+ assert(attr1.hashCode != attr2.hashCode)
+
+ map1 = map1.updated(attr1, 100)
+ map1 = map1.updated(attr2, 200)
+ assert(map1.get(attr2) === Some(200))
+
+ map2 = map2.updated(attr2, 200)
+ map2 = map2.updated(attr1, 100)
+ assert(map2.get(attr2) === Some(100))
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]