This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.3 by this push:
new a7480e647fe [SPARK-38823][SQL] Make `NewInstance` non-foldable to fix
aggregation buffer corruption issue
a7480e647fe is described below
commit a7480e647fe1ed930c0cd2ad1679b3685a675d02
Author: Bruce Robbins <[email protected]>
AuthorDate: Fri Apr 15 08:37:43 2022 +0900
[SPARK-38823][SQL] Make `NewInstance` non-foldable to fix aggregation
buffer corruption issue
### What changes were proposed in this pull request?
Make `NewInstance` non-foldable.
### Why are the changes needed?
When handling Java beans as input, Spark creates `NewInstance` with no
arguments. On master and 3.3, `NewInstance` with no arguments is considered
foldable. As a result, the `ConstantFolding` rule converts `NewInstance` into a
`Literal` holding an instance of the user's specified Java bean. The instance
becomes a singleton that gets reused for each input record (although its fields
get updated by `InitializeJavaBean`).
Because the instance gets reused, sometimes multiple buffers in
`AggregationIterator` are actually referring to the same Java bean instance.
Take, for example, the test I added in this PR, or the `spark-shell`
example I added to SPARK-38823 as a comment.
The input is:
```
new Item("a", 1),
new Item("b", 3),
new Item("c", 2),
new Item("a", 7)
```
As `ObjectAggregationIterator` reads the input, the buffers get set up as
follows (note that the first field of Item should be the same as the key):
```
- Read Item("a", 1)
- Buffers are now:
Key "a" -> Item("a", 1)
- Read Item("b", 3)
- Buffers are now:
Key "a" -> Item("b", 3)
Key "b" -> Item("b", 3)
```
The buffer for key "a" now contains `Item("b", 3)`. That's because both
buffers contain a reference to the same Item instance, and that Item instance's
fields were updated when `Item("b", 3)` was read.
This PR makes `NewInstance` non-foldable, so it will not get optimized
away, thus ensuring a new instance of the Java bean for each input record.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
New unit test.
Closes #36183 from bersprockets/newinstance_issue.
Authored-by: Bruce Robbins <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
(cherry picked from commit cc7cb7a803d5de03c526480c8968bbb2c3e82484)
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../sql/catalyst/expressions/objects/objects.scala | 3 +
.../catalyst/optimizer/ConstantFoldingSuite.scala | 15 +---
.../spark/sql/JavaBeanDeserializationSuite.java | 93 ++++++++++++++++++++++
3 files changed, 99 insertions(+), 12 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 2c879beeed6..fe982b23829 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -516,6 +516,9 @@ case class NewInstance(
override def nullable: Boolean = needNullCheck
+ // Non-foldable to prevent the optimizer from replacing NewInstance with a
singleton instance
+ // of the specified class.
+ override def foldable: Boolean = false
override def children: Seq[Expression] = arguments
final override val nodePatterns: Seq[TreePattern] = Seq(NEW_INSTANCE)
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
index a2ee2a2fb68..b06e001e412 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
@@ -21,11 +21,10 @@ import
org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, Unresol
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, NewInstance,
StaticInvoke}
+import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
-import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.ByteArray
@@ -318,14 +317,7 @@ class ConstantFoldingSuite extends PlanTest {
Literal.create("a", StringType),
"substring",
StringType,
- Seq(Literal(0), Literal(1))).as("c2"),
- NewInstance(
- cls = classOf[GenericArrayData],
- arguments = Literal.fromObject(List(1, 2, 3)) :: Nil,
- inputTypes = Nil,
- propagateNull = false,
- dataType = ArrayType(IntegerType),
- outerPointer = None).as("c3"))
+ Seq(Literal(0), Literal(1))).as("c2"))
val optimized = Optimize.execute(originalQuery.analyze)
@@ -333,8 +325,7 @@ class ConstantFoldingSuite extends PlanTest {
testRelation
.select(
Literal("WWSpark".getBytes()).as("c1"),
- Literal.create("a", StringType).as("c2"),
- Literal.create(new GenericArrayData(List(1, 2, 3)),
ArrayType(IntegerType)).as("c3"))
+ Literal.create("a", StringType).as("c2"))
.analyze
comparePlans(optimized, correctAnswer)
diff --git
a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
index af0a22b0360..06a5c50b30c 100644
---
a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
+++
b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
@@ -26,6 +26,8 @@ import java.util.*;
import org.apache.commons.lang3.builder.ToStringBuilder;
import org.apache.commons.lang3.builder.ToStringStyle;
+import org.apache.spark.api.java.function.MapFunction;
+import org.apache.spark.api.java.function.ReduceFunction;
import org.junit.*;
import org.apache.spark.sql.*;
@@ -37,6 +39,7 @@ import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.test.TestSparkSession;
+import scala.Tuple2;
public class JavaBeanDeserializationSuite implements Serializable {
@@ -562,6 +565,96 @@ public class JavaBeanDeserializationSuite implements
Serializable {
}
}
+ @Test
+ public void testSPARK38823NoBeanReuse() {
+ List<Item> items = Arrays.asList(
+ new Item("a", 1),
+ new Item("b", 3),
+ new Item("c", 2),
+ new Item("a", 7));
+
+ Encoder<Item> encoder = Encoders.bean(Item.class);
+
+ Dataset<Item> ds = spark.createDataFrame(items, Item.class)
+ .as(encoder)
+ .coalesce(1);
+
+ MapFunction<Item, String> mf = new MapFunction<Item, String>() {
+ @Override
+ public String call(Item item) throws Exception {
+ return item.getK();
+ }
+ };
+
+ ReduceFunction<Item> rf = new ReduceFunction<Item>() {
+ @Override
+ public Item call(Item item1, Item item2) throws Exception {
+ Assert.assertNotSame(item1, item2);
+ return item1.addValue(item2.getV());
+ }
+ };
+
+ Dataset<Tuple2<String, Item>> finalDs = ds
+ .groupByKey(mf, Encoders.STRING())
+ .reduceGroups(rf);
+
+ List<Tuple2<String, Item>> expectedRecords = Arrays.asList(
+ new Tuple2("a", new Item("a", 8)),
+ new Tuple2("b", new Item("b", 3)),
+ new Tuple2("c", new Item("c", 2)));
+
+ List<Tuple2<String, Item>> result = finalDs.collectAsList();
+
+ Assert.assertEquals(expectedRecords, result);
+ }
+
+ public static class Item implements Serializable {
+ private String k;
+ private int v;
+
+ public String getK() {
+ return k;
+ }
+
+ public int getV() {
+ return v;
+ }
+
+ public void setK(String k) {
+ this.k = k;
+ }
+
+ public void setV(int v) {
+ this.v = v;
+ }
+
+ public Item() { }
+
+ public Item(String k, int v) {
+ this.k = k;
+ this.v = v;
+ }
+
+ public Item addValue(int inc) {
+ return new Item(k, v + inc);
+ }
+
+ public String toString() {
+ return "Item(" + k + "," + v + ")";
+ }
+
+ public boolean equals(Object o) {
+ if (!(o instanceof Item)) {
+ return false;
+ }
+ Item other = (Item) o;
+ if (other.getK().equals(k) && other.getV() == v) {
+ return true;
+ }
+ return false;
+ }
+ }
+
public static final class LocalDateInstantRecord {
private String localDateField;
private String instantField;
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]