This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 667c0a9dbbe0 [SPARK-46791][SQL] Support Java Set in JavaTypeInference
667c0a9dbbe0 is described below
commit 667c0a9dbbe045c73842a345c1b3897b155564d4
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Mon Jan 22 02:13:12 2024 -0800
[SPARK-46791][SQL] Support Java Set in JavaTypeInference
### What changes were proposed in this pull request?
This patch adds support for Java `Set` as a bean field in
`JavaTypeInference`.
### Why are the changes needed?
Scala `Set` (`scala.collection.Set`) is supported in `ScalaReflection` so
users can encode Scala `Set` in Dataset. But Java `Set` is not supported in
bean encoder (i.e., `JavaTypeInference`). This feature inconsistency means Java
users cannot use `Set` the way Scala users can.
### Does this PR introduce _any_ user-facing change?
Yes. Java `Set` is now supported as part of a Java bean when encoding with the
bean encoder.
### How was this patch tested?
Added tests.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #44828 from viirya/java_set.
Authored-by: Liang-Chi Hsieh <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
.../spark/sql/catalyst/JavaTypeInference.scala | 6 ++-
.../sql/catalyst/expressions/objects/objects.scala | 50 ++++++++++++++++++++++
.../sql/catalyst/JavaTypeInferenceSuite.scala | 26 +++++++++--
.../expressions/ObjectExpressionsSuite.scala | 5 ++-
.../org/apache/spark/sql/JavaDatasetSuite.java | 45 +++++++++++++++++++
.../scala/org/apache/spark/sql/DatasetSuite.scala | 9 ++++
6 files changed, 136 insertions(+), 5 deletions(-)
diff --git
a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index a945cb720b01..f85e96da2be1 100644
---
a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++
b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst
import java.beans.{Introspector, PropertyDescriptor}
import java.lang.reflect.{ParameterizedType, Type, TypeVariable}
-import java.util.{List => JList, Map => JMap}
+import java.util.{List => JList, Map => JMap, Set => JSet}
import javax.annotation.Nonnull
import scala.jdk.CollectionConverters._
@@ -112,6 +112,10 @@ object JavaTypeInference {
val element = encoderFor(c.getTypeParameters.array(0), seenTypeSet,
typeVariables)
IterableEncoder(ClassTag(c), element, element.nullable,
lenientSerialization = false)
+ case c: Class[_] if classOf[JSet[_]].isAssignableFrom(c) =>
+ val element = encoderFor(c.getTypeParameters.array(0), seenTypeSet,
typeVariables)
+ IterableEncoder(ClassTag(c), element, element.nullable,
lenientSerialization = false)
+
case c: Class[_] if classOf[JMap[_, _]].isAssignableFrom(c) =>
val keyEncoder = encoderFor(c.getTypeParameters.array(0), seenTypeSet,
typeVariables)
val valueEncoder = encoderFor(c.getTypeParameters.array(1), seenTypeSet,
typeVariables)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index bae2922cf921..a684ca18435e 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -907,6 +907,8 @@ case class MapObjects private(
_.asInstanceOf[Array[_]].toImmutableArraySeq
case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
_.asInstanceOf[java.util.List[_]].asScala.toSeq
+ case ObjectType(cls) if classOf[java.util.Set[_]].isAssignableFrom(cls) =>
+ _.asInstanceOf[java.util.Set[_]].asScala.toSeq
case ObjectType(cls) if cls == classOf[Object] =>
(inputCollection) => {
if (inputCollection.getClass.isArray) {
@@ -982,6 +984,34 @@ case class MapObjects private(
builder
}
}
+ case Some(cls) if classOf[java.util.Set[_]].isAssignableFrom(cls) =>
+ // Java set
+ if (cls == classOf[java.util.Set[_]] || cls ==
classOf[java.util.AbstractSet[_]]) {
+ // Specifying non concrete implementations of `java.util.Set`
+ executeFuncOnCollection(_).toSet.asJava
+ } else {
+ val constructors = cls.getConstructors()
+ val intParamConstructor = constructors.find { constructor =>
+ constructor.getParameterCount == 1 &&
constructor.getParameterTypes()(0) == classOf[Int]
+ }
+ val noParamConstructor = constructors.find { constructor =>
+ constructor.getParameterCount == 0
+ }
+
+ val constructor = intParamConstructor.map { intConstructor =>
+ (len: Int) => intConstructor.newInstance(len.asInstanceOf[Object])
+ }.getOrElse {
+ (_: Int) => noParamConstructor.get.newInstance()
+ }
+
+ // Specifying concrete implementations of `java.util.Set`
+ (inputs) => {
+ val results = executeFuncOnCollection(inputs)
+ val builder =
constructor(inputs.length).asInstanceOf[java.util.Set[Any]]
+ results.foreach(builder.add(_))
+ builder
+ }
+ }
case None =>
// array
x => new GenericArrayData(executeFuncOnCollection(x).toArray)
@@ -1067,6 +1097,13 @@ case class MapObjects private(
s"java.util.Iterator $it = ${genInputData.value}.iterator();",
s"$it.next()"
)
+ case ObjectType(cls) if classOf[java.util.Set[_]].isAssignableFrom(cls)
=>
+ val it = ctx.freshName("it")
+ (
+ s"${genInputData.value}.size()",
+ s"java.util.Iterator $it = ${genInputData.value}.iterator();",
+ s"$it.next()"
+ )
case ArrayType(et, _) =>
(
s"${genInputData.value}.numElements()",
@@ -1158,6 +1195,19 @@ case class MapObjects private(
(genValue: String) => s"$builder.add($genValue);",
s"$builder;"
)
+ case Some(cls) if classOf[java.util.Set[_]].isAssignableFrom(cls) =>
+ // Java set
+ val builder = ctx.freshName("collectionBuilder")
+ (
+ if (cls == classOf[java.util.Set[_]] || cls ==
classOf[java.util.AbstractSet[_]]) {
+ s"${cls.getName} $builder = new java.util.HashSet($dataLength);"
+ } else {
+ val param = Try(cls.getConstructor(Integer.TYPE)).map(_ =>
dataLength).getOrElse("")
+ s"${cls.getName} $builder = new ${cls.getName}($param);"
+ },
+ (genValue: String) => s"$builder.add($genValue);",
+ s"$builder;"
+ )
case _ =>
// array
(
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala
index bef0cf8831eb..c785c71428ca 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst
import java.math.BigInteger
-import java.util.{LinkedList, List => JList, Map => JMap}
+import java.util.{HashSet, LinkedList, List => JList, Map => JMap, Set => JSet}
import scala.beans.{BeanProperty, BooleanBeanProperty}
import scala.reflect.{classTag, ClassTag}
@@ -37,6 +37,8 @@ class GenericCollectionBean {
@BeanProperty var listOfListOfStrings: JList[JList[String]] = _
@BeanProperty var mapOfDummyBeans: JMap[String, DummyBean] = _
@BeanProperty var linkedListOfStrings: LinkedList[String] = _
+ @BeanProperty var hashSetOfString: HashSet[String] = _
+ @BeanProperty var setOfSetOfStrings: JSet[JSet[String]] = _
}
class LeafBean {
@@ -139,9 +141,16 @@ class JavaTypeInferenceSuite extends SparkFunSuite {
assert(schema === expected)
}
- test("resolve type parameters for map and list") {
+ test("resolve type parameters for map, list and set") {
val encoder = JavaTypeInference.encoderFor(classOf[GenericCollectionBean])
val expected = JavaBeanEncoder(ClassTag(classOf[GenericCollectionBean]),
Seq(
+ encoderField(
+ "hashSetOfString",
+ IterableEncoder(
+ ClassTag(classOf[HashSet[_]]),
+ StringEncoder,
+ containsNull = true,
+ lenientSerialization = false)),
encoderField(
"linkedListOfStrings",
IterableEncoder(
@@ -166,7 +175,18 @@ class JavaTypeInferenceSuite extends SparkFunSuite {
ClassTag(classOf[JMap[_, _]]),
StringEncoder,
expectedDummyBeanEncoder,
- valueContainsNull = true))))
+ valueContainsNull = true)),
+ encoderField(
+ "setOfSetOfStrings",
+ IterableEncoder(
+ ClassTag(classOf[JSet[_]]),
+ IterableEncoder(
+ ClassTag(classOf[JSet[_]]),
+ StringEncoder,
+ containsNull = true,
+ lenientSerialization = false),
+ containsNull = true,
+ lenientSerialization = false))))
assert(encoder === expected)
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
index 538a7600b02a..7f58516cf4eb 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -362,6 +362,8 @@ class ObjectExpressionsSuite extends SparkFunSuite with
ExpressionEvalHelper {
assert(result.asInstanceOf[ArrayData].array.toSeq == expected)
case l if classOf[java.util.List[_]].isAssignableFrom(l) =>
assert(result.asInstanceOf[java.util.List[_]].asScala == expected)
+ case s if classOf[java.util.Set[_]].isAssignableFrom(s) =>
+ assert(result.asInstanceOf[java.util.Set[_]].asScala ==
expected.toSet)
case a if classOf[mutable.ArraySeq[Int]].isAssignableFrom(a) =>
assert(result == mutable.ArraySeq.make[Int](expected.toArray))
case a if classOf[immutable.ArraySeq[Int]].isAssignableFrom(a) =>
@@ -379,7 +381,8 @@ class ObjectExpressionsSuite extends SparkFunSuite with
ExpressionEvalHelper {
classOf[Seq[Int]], classOf[scala.collection.Set[Int]],
classOf[java.util.List[Int]], classOf[java.util.AbstractList[Int]],
classOf[java.util.AbstractSequentialList[Int]],
classOf[java.util.Vector[Int]],
- classOf[java.util.Stack[Int]], null)
+ classOf[java.util.Stack[Int]], null,
+ classOf[java.util.Set[Int]])
val list = new java.util.ArrayList[Int]()
list.add(1)
diff --git
a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 254c6df28209..bd776300bd5e 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -111,6 +111,26 @@ public class JavaDatasetSuite implements Serializable {
Assertions.assertEquals(ds.schema(), ds2.schema());
}
+ @Test
+ public void testBeanWithSet() {
+ BeanWithSet bean = new BeanWithSet();
+ Set<Long> fields = asSet(1L, 2L, 3L);
+ bean.setFields(fields);
+ List<BeanWithSet> objects = Collections.singletonList(bean);
+
+ Dataset<BeanWithSet> ds = spark.createDataset(objects,
Encoders.bean(BeanWithSet.class));
+ Dataset<Row> df = ds.toDF();
+
+ Dataset<BeanWithSet> mapped =
+ df.map((MapFunction<Row, BeanWithSet>) row -> {
+ BeanWithSet obj = new BeanWithSet();
+ obj.setFields(new
HashSet<>(row.<Long>getList(row.fieldIndex("fields"))));
+ return obj;
+ }, Encoders.bean(BeanWithSet.class));
+
+ Assertions.assertEquals(objects, mapped.collectAsList());
+ }
+
@Test
public void testCommonOperation() {
List<String> data = Arrays.asList("hello", "world");
@@ -1989,6 +2009,31 @@ public class JavaDatasetSuite implements Serializable {
Assertions.assertEquals(expected, df.collectAsList());
}
+ public static class BeanWithSet implements Serializable {
+ private Set<Long> fields;
+
+ public Set<Long> getFields() {
+ return fields;
+ }
+
+ public void setFields(Set<Long> fields) {
+ this.fields = fields;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ BeanWithSet that = (BeanWithSet) o;
+ return Objects.equal(fields, that.fields);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hashCode(fields);
+ }
+ }
+
public static class SpecificListsBean implements Serializable {
private ArrayList<Integer> arrayList;
private LinkedList<Integer> linkedList;
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index cd28c60d83c7..f0f48026a4a0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql
import java.io.{Externalizable, ObjectInput, ObjectOutput}
import java.sql.{Date, Timestamp}
+import scala.collection.immutable.HashSet
import scala.reflect.ClassTag
import scala.util.Random
@@ -2706,6 +2707,12 @@ class DatasetSuite extends QueryTest
assert(exception.context.head.asInstanceOf[DataFrameQueryContext].stackTrace.length
== 2)
}
}
+
+ test("SPARK-46791: Dataset with set field") {
+ val ds = Seq(WithSet(0, HashSet("foo", "bar")), WithSet(1, HashSet("bar",
"zoo"))).toDS()
+ checkDataset(ds.map(t => t),
+ WithSet(0, HashSet("foo", "bar")), WithSet(1, HashSet("bar", "zoo")))
+ }
}
class DatasetLargeResultCollectingSuite extends QueryTest
@@ -2759,6 +2766,8 @@ case class WithImmutableMap(id: String, map_test:
scala.collection.immutable.Map
case class WithMap(id: String, map_test: scala.collection.Map[Long, String])
case class WithMapInOption(m: Option[scala.collection.Map[Int, Int]])
+case class WithSet(id: Int, values: Set[String])
+
case class Generic[T](id: T, value: Double)
case class OtherTuple(_1: String, _2: Int)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]