sririshindra commented on code in PR #48252:
URL: https://github.com/apache/spark/pull/48252#discussion_r1844241460
##########
sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala:
##########
@@ -148,34 +163,180 @@ object JavaTypeInference {
      // TODO: we should only collect properties that have getter and setter. However, some tests
      // pass in scala case class as java bean class which doesn't have getter and setter.
      val properties = getJavaBeanReadableProperties(c)
-      // add type variables from inheritance hierarchy of the class
-      val classTV = JavaTypeUtils.getTypeArguments(c, classOf[Object]).asScala.toMap ++
-        typeVariables
-      // Note that the fields are ordered by name.
-      val fields = properties.map { property =>
-        val readMethod = property.getReadMethod
-        val encoder = encoderFor(readMethod.getGenericReturnType, seenTypeSet + c, classTV)
-        // The existence of `javax.annotation.Nonnull`, means this field is not nullable.
-        val hasNonNull = readMethod.isAnnotationPresent(classOf[Nonnull])
-        EncoderField(
-          property.getName,
-          encoder,
-          encoder.nullable && !hasNonNull,
-          Metadata.empty,
-          Option(readMethod.getName),
-          Option(property.getWriteMethod).map(_.getName))
+
+      // if the properties is empty and this is not a top level enclosing class, then we
+      // should not consider class as bean, as otherwise it will be treated as empty schema
+      // and loose the data on deser.
+      if (properties.isEmpty && seenTypeSet.nonEmpty) {
+        findBestEncoder(Seq(c), seenTypeSet, typeVariables, None, serializableEncodersOnly = true)
+          .getOrElse(throw ExecutionErrors.cannotFindEncoderForTypeError(t.getTypeName))
+      } else {
+        // add type variables from inheritance hierarchy of the class
+        val parentClassesTypeMap =
+          JavaTypeUtils.getTypeArguments(c, classOf[Object]).asScala.toMap
+        val classTV = parentClassesTypeMap ++ typeVariables
+        // Note that the fields are ordered by name.
+        val fields = properties.map { property =>
+          val readMethod = property.getReadMethod
+          val methodReturnType = readMethod.getGenericReturnType
+          val encoder = encoderFor(methodReturnType, seenTypeSet + c, classTV)
+          // The existence of `javax.annotation.Nonnull`, means this field is not nullable.
+          val hasNonNull = readMethod.isAnnotationPresent(classOf[Nonnull])
+          EncoderField(
+            property.getName,
+            encoder,
+            encoder.nullable && !hasNonNull,
+            Metadata.empty,
+            Option(readMethod.getName),
+            Option(property.getWriteMethod).map(_.getName))
+        }
+        // implies it cannot be assumed a BeanClass.
+        // Check if its super class or interface could be represented by an Encoder
+
+        JavaBeanEncoder(ClassTag(c), fields.toImmutableArraySeq)
      }
-      JavaBeanEncoder(ClassTag(c), fields.toImmutableArraySeq)
    case _ =>
      throw ExecutionErrors.cannotFindEncoderForTypeError(t.toString)
  }
+  private def createUDTEncoderUsingAnnotation(c: Class[_]): UDTEncoder[Any] = {
+    val udt = c
+      .getAnnotation(classOf[SQLUserDefinedType])
+      .udt()
+      .getConstructor()
+      .newInstance()
+      .asInstanceOf[UserDefinedType[Any]]
+    val udtClass = udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()
+    UDTEncoder(udt, udtClass)
+  }
+
+  private def createUDTEncoderUsingRegistration(c: Class[_]): UDTEncoder[Any] = {
+    val udt = UDTRegistration
+      .getUDTFor(c.getName)
+      .get
+      .getConstructor()
+      .newInstance()
+      .asInstanceOf[UserDefinedType[Any]]
+    UDTEncoder(udt, udt.getClass)
+  }
+
  def getJavaBeanReadableProperties(beanClass: Class[_]): Array[PropertyDescriptor] = {
    val beanInfo = Introspector.getBeanInfo(beanClass)
    beanInfo.getPropertyDescriptors
      .filterNot(_.getName == "class")
      .filterNot(_.getName == "declaringClass")
      .filter(_.getReadMethod != null)
  }
+
+  private def findBestEncoder(
+      typesToCheck: Seq[Class[_]],
+      seenTypeSet: Set[Class[_]],
+      typeVariables: Map[TypeVariable[_], Type],
+      baseClass: Option[Class[_]],
+      serializableEncodersOnly: Boolean = false): Option[AgnosticEncoder[_]] =
+    if (serializableEncodersOnly) {
+      val isClientConnect = clientConnectFlag.get
+      assert(typesToCheck.size == 1)
+      typesToCheck
+        .flatMap(c => {
+          if (!isClientConnect && classOf[KryoSerializable].isAssignableFrom(c)) {
Review Comment:
Can we add a comment explaining why `isClientConnect` being true disqualifies
the type from being encoded with the Kryo serializer? Is there a chance that
will change in the future? If so, can we add a TODO here so that we can remove
this condition if and when Kryo serialization is available with Spark Connect?
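
For example, the comment could look something like this (a rough sketch reusing
the names from this diff; the stated reason is only my guess, so please
substitute the real one):

```scala
if (!isClientConnect && classOf[KryoSerializable].isAssignableFrom(c)) {
  // Only classic (non-Connect) sessions may fall back to Kryo here, because
  // <actual reason to be filled in by the author>.
  // TODO: remove this condition if and when Kryo serialization becomes
  // available with Spark Connect.
  // ... existing Kryo encoder branch ...
}
```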
##########
sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala:
##########
@@ -148,34 +163,180 @@ object JavaTypeInference {
      // TODO: we should only collect properties that have getter and setter. However, some tests
      // pass in scala case class as java bean class which doesn't have getter and setter.
      val properties = getJavaBeanReadableProperties(c)
-      // add type variables from inheritance hierarchy of the class
-      val classTV = JavaTypeUtils.getTypeArguments(c, classOf[Object]).asScala.toMap ++
-        typeVariables
-      // Note that the fields are ordered by name.
-      val fields = properties.map { property =>
-        val readMethod = property.getReadMethod
-        val encoder = encoderFor(readMethod.getGenericReturnType, seenTypeSet + c, classTV)
-        // The existence of `javax.annotation.Nonnull`, means this field is not nullable.
-        val hasNonNull = readMethod.isAnnotationPresent(classOf[Nonnull])
-        EncoderField(
-          property.getName,
-          encoder,
-          encoder.nullable && !hasNonNull,
-          Metadata.empty,
-          Option(readMethod.getName),
-          Option(property.getWriteMethod).map(_.getName))
+
+      // if the properties is empty and this is not a top level enclosing class, then we
+      // should not consider class as bean, as otherwise it will be treated as empty schema
+      // and loose the data on deser.
+      if (properties.isEmpty && seenTypeSet.nonEmpty) {
+        findBestEncoder(Seq(c), seenTypeSet, typeVariables, None, serializableEncodersOnly = true)
Review Comment:
Could you please elaborate on why the `serializableEncodersOnly` flag is needed
here? Maybe adding a comment for that would make it clearer.
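
Even a short doc comment on the parameter would help; for example (a sketch,
the wording is only my guess at the intent):

```scala
/**
 * @param serializableEncodersOnly when true, skip bean-style inference and
 *        only consider the serialization-based fallback encoders. My guess:
 *        this is for classes with no readable bean properties, which would
 *        otherwise get an empty schema and lose their state.
 */
```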
##########
sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala:
##########
@@ -2909,6 +3016,7 @@ object KryoData {
/** Used to test Java encoder. */
class JavaData(val a: Int) extends Serializable {
+ def this() = this(0)
Review Comment:
It is not clear to me why the method `def this()` needs to be explicitly
defined. I don't see it being used anywhere. Can we remove it, or could you
please explain why it is here?
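
My best guess is that the bean/serialization fallback instantiates the class
reflectively and then populates it, which requires a no-arg constructor. A
minimal, self-contained sketch of that pattern (hypothetical `Box` class, not
the PR's code):

```scala
// Hypothetical bean-like class, analogous to JavaData in this suite.
class Box(var a: Int) extends Serializable {
  def this() = this(0) // without this, getDeclaredConstructor() below throws
}

object NoArgCtorDemo {
  def main(args: Array[String]): Unit = {
    // Reflective instantiation, the way a bean encoder typically creates an
    // instance before populating it via setters or field writes.
    val box = classOf[Box].getDeclaredConstructor().newInstance()
    box.a = 42
    println(box.a) // prints 42
  }
}
```

If that is indeed the reason, a one-line comment on `def this()` would make it
obvious.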
##########
sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala:
##########
@@ -2802,6 +2821,79 @@ class DatasetSuite extends QueryTest
}
}
}
+
+ test("SPARK-49789 Bean class encoding with generic type implementing
Serializable") {
+ // just create encoder
+ val enc = Encoders.bean(classOf[MessageWrapper[_]])
+ val data = Seq("test1", "test2").map(str => {
+ val msg = new MessageWrapper[String]()
+ msg.setMessage(str)
+ msg
+ })
+ validateParamBeanDataset(classOf[MessageWrapper[String]],
+ data, mutable.Buffer(data: _*),
+ StructType(Seq(StructField("message", BinaryType, true)))
+ )
+ }
+
+ test("SPARK-49789 Bean class encoding with generic type indirectly
extending" +
+ " Serializable class") {
+ // just create encoder
+ Encoders.bean(classOf[BigDecimalMessageWrapper[_]])
+ val data = Seq(2d, 8d).map(doub => {
+ val bean = new BigDecimalMessageWrapper[DerivedBigDecimalExtender]()
+ bean.setMessage(new DerivedBigDecimalExtender(doub))
+ bean
+ })
+ validateParamBeanDataset(
+ classOf[BigDecimalMessageWrapper[DerivedBigDecimalExtender]],
+ data, mutable.Buffer(data: _*),
+ StructType(Seq(StructField("message", BinaryType, true))))
+ }
+
+ test("SPARK-49789. test bean class with generictype bound of UDTType") {
+ // just create encoder
+ UDTRegistration.register(classOf[TestUDT].getName,
classOf[TestUDTType].getName)
+ val enc = Encoders.bean(classOf[UDTBean[_]])
+ val baseData = Seq((1, "a"), (2, "b"))
+ val data = baseData.map(tup => {
+ val bean = new UDTBean[TestUDT]()
+ bean.setMessage(new TestUDTImplSub(tup._1, tup._2))
+ bean
+ })
+ val expectedData = baseData.map(tup => {
+ val bean = new UDTBean[TestUDT]()
+ bean.setMessage(new TestUDTImpl(tup._1, tup._2))
+ bean
+ })
+ validateParamBeanDataset(
+ classOf[UDTBean[TestUDT]],
+ data, mutable.Buffer(expectedData: _*),
+ StructType(Seq(StructField("message", new TestUDTType(), true))))
+ }
+
+  private def validateParamBeanDataset[T](
+    classToEncode: Class[T],
+    data: Seq[T],
+    expectedData: mutable.Buffer[T],
+    expectedSchema: StructType): Unit = {
+
Review Comment:
Indentation seems to be off. Could you please fix that?
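
For reference, the usual style in this file would be something like:

```scala
private def validateParamBeanDataset[T](
    classToEncode: Class[T],
    data: Seq[T],
    expectedData: mutable.Buffer[T],
    expectedSchema: StructType): Unit = {
```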
##########
connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala:
##########
@@ -137,8 +137,16 @@ class SparkSession private[sql] (
/** @inheritdoc */
  def createDataFrame(data: java.util.List[_], beanClass: Class[_]): DataFrame = {
-    val encoder = JavaTypeInference.encoderFor(beanClass.asInstanceOf[Class[Any]])
-    createDataset(encoder, data.iterator().asScala).toDF()
+    JavaTypeInference.setSparkClientFlag()
+    val encoderTry = Try {
+      JavaTypeInference.encoderFor(beanClass.asInstanceOf[Class[Any]])
+    }
+    JavaTypeInference.unsetSparkClientFlag()
Review Comment:
Can we add a comment explaining exactly why we are setting and unsetting the
Spark client flag here?
My understanding is that whether we use a regular SparkSession or one created
through Spark Connect makes a difference in what we can infer. But the exact
reason for the difference between a regular SparkSession and a Connect
SparkSession, and what that difference actually is, is not clear to me. Can we
please document that here, or where the client connect flag is initialized?
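
Also, since `encoderFor` can throw, would it make sense to clear the flag in a
`finally`? A rough sketch using the methods from this diff:

```scala
JavaTypeInference.setSparkClientFlag()
val encoderTry =
  try {
    // Try(...) captures inference failures as a value; the finally block
    // guarantees the flag is cleared even on errors Try does not catch.
    Try(JavaTypeInference.encoderFor(beanClass.asInstanceOf[Class[Any]]))
  } finally {
    JavaTypeInference.unsetSparkClientFlag()
  }
```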
##########
sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala:
##########
@@ -2945,3 +3053,132 @@ case class SaveModeArrayCase(modes: Array[SaveMode])
case class K1(a: Long)
case class K2(a: Long, b: Long)
+
+class MessageWrapper[T <: java.io.Serializable] extends java.io.Serializable {
+ private var message: T = _
+
+ def getMessage: T = message
+
+ def setMessage(message: T): Unit = {
+ this.message = message
+ }
+
+ override def equals(obj: Any): Boolean = {
+ obj match {
+ case m: MessageWrapper[_] => m.message == this.message
+
+ case _ => false
+ }
+ }
+
+ override def hashCode(): Int = this.message.hashCode()
+}
+
+class BigDecimalMessageWrapper[T <: BigDecimalExtender] extends java.io.Serializable {
+ private var message: T = _
+
+ def getMessage: T = message
+
+ def setMessage(message: T): Unit = {
+ this.message = message
+ }
+
+ override def equals(obj: Any): Boolean = {
+ obj match {
+ case m: BigDecimalMessageWrapper[_] => m.message == this.message
+ case _ => false
+ }
+ }
+
+ override def hashCode(): Int = this.message.hashCode()
+}
+
+class BigDecimalExtender(doub: Double) extends java.math.BigDecimal(doub) {
+ override def equals(obj: Any): Boolean = {
+ obj match {
+      case m: BigDecimalExtender => super.equals(m.asInstanceOf[java.math.BigDecimal])
+ case _ => false
+ }
+ }
+
+ override def hashCode(): Int = super.hashCode()
+}
+
+class DerivedBigDecimalExtender(doub: Double) extends BigDecimalExtender(doub) {
+ override def equals(obj: Any): Boolean = {
+ obj match {
+      case m: DerivedBigDecimalExtender => super.equals(m.asInstanceOf[BigDecimalExtender])
+ case _ => false
+ }
+ }
+
+ override def hashCode(): Int = super.hashCode()
+}
+
+trait TestUDT extends Serializable {
+ def intField: Int
+
+ def stringField: String
+}
+
+class TestUDTImpl(var intF: Int, var stringF: String) extends TestUDT {
+ def this() = this(0, "")
+
+ override def intField: Int = intF
+
+ override def stringField: String = stringF
+
+ override def hashCode(): Int = intF.hashCode() + stringF.hashCode
+
+ override def equals(obj: Any): Boolean = obj match {
+    case b: TestUDT => b.intField == this.intField && b.stringField == this.stringField
+
+ case _ => false
+ }
+}
+
+class TestUDTImplSub(var iF: Int, var sF: String) extends TestUDTImpl(iF, sF) {
+ def this() = this(0, "")
Review Comment:
Ditto
##########
sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala:
##########
@@ -148,34 +163,180 @@ object JavaTypeInference {
      // TODO: we should only collect properties that have getter and setter. However, some tests
      // pass in scala case class as java bean class which doesn't have getter and setter.
      val properties = getJavaBeanReadableProperties(c)
-      // add type variables from inheritance hierarchy of the class
-      val classTV = JavaTypeUtils.getTypeArguments(c, classOf[Object]).asScala.toMap ++
-        typeVariables
-      // Note that the fields are ordered by name.
-      val fields = properties.map { property =>
-        val readMethod = property.getReadMethod
-        val encoder = encoderFor(readMethod.getGenericReturnType, seenTypeSet + c, classTV)
-        // The existence of `javax.annotation.Nonnull`, means this field is not nullable.
-        val hasNonNull = readMethod.isAnnotationPresent(classOf[Nonnull])
-        EncoderField(
-          property.getName,
-          encoder,
-          encoder.nullable && !hasNonNull,
-          Metadata.empty,
-          Option(readMethod.getName),
-          Option(property.getWriteMethod).map(_.getName))
+
+      // if the properties is empty and this is not a top level enclosing class, then we
+      // should not consider class as bean, as otherwise it will be treated as empty schema
+      // and loose the data on deser.
Review Comment:
nit: Can you spell out 'deser' as 'deserialization' for more clarity?
Also, I did not quite understand why, if 'this is not a top level enclosing
class', 'it will be treated as empty schema'. What does that mean exactly?
Could you please elaborate?
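
If my reading is right, something along these lines would be clearer (please
correct the reasoning if I have it wrong):

```scala
// If the class exposes no readable properties and we are already inside the
// inference of an enclosing bean, do not model it as a bean: a bean encoder
// with zero fields yields an empty schema, so the object's state would be
// silently dropped during deserialization. Fall back to a serialization-based
// encoder instead.
```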
##########
sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala:
##########
@@ -2945,3 +3053,132 @@ case class SaveModeArrayCase(modes: Array[SaveMode])
case class K1(a: Long)
case class K2(a: Long, b: Long)
+
+class MessageWrapper[T <: java.io.Serializable] extends java.io.Serializable {
+ private var message: T = _
+
+ def getMessage: T = message
+
+ def setMessage(message: T): Unit = {
+ this.message = message
+ }
+
+ override def equals(obj: Any): Boolean = {
+ obj match {
+ case m: MessageWrapper[_] => m.message == this.message
+
+ case _ => false
+ }
+ }
+
+ override def hashCode(): Int = this.message.hashCode()
+}
+
+class BigDecimalMessageWrapper[T <: BigDecimalExtender] extends java.io.Serializable {
+ private var message: T = _
+
+ def getMessage: T = message
+
+ def setMessage(message: T): Unit = {
+ this.message = message
+ }
+
+ override def equals(obj: Any): Boolean = {
+ obj match {
+ case m: BigDecimalMessageWrapper[_] => m.message == this.message
+ case _ => false
+ }
+ }
+
+ override def hashCode(): Int = this.message.hashCode()
+}
+
+class BigDecimalExtender(doub: Double) extends java.math.BigDecimal(doub) {
+ override def equals(obj: Any): Boolean = {
+ obj match {
+      case m: BigDecimalExtender => super.equals(m.asInstanceOf[java.math.BigDecimal])
+ case _ => false
+ }
+ }
+
+ override def hashCode(): Int = super.hashCode()
+}
+
+class DerivedBigDecimalExtender(doub: Double) extends BigDecimalExtender(doub) {
+ override def equals(obj: Any): Boolean = {
+ obj match {
+      case m: DerivedBigDecimalExtender => super.equals(m.asInstanceOf[BigDecimalExtender])
+ case _ => false
+ }
+ }
+
+ override def hashCode(): Int = super.hashCode()
+}
+
+trait TestUDT extends Serializable {
+ def intField: Int
+
+ def stringField: String
+}
+
+class TestUDTImpl(var intF: Int, var stringF: String) extends TestUDT {
+ def this() = this(0, "")
Review Comment:
Ditto
##########
sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala:
##########
@@ -148,34 +163,180 @@ object JavaTypeInference {
      // TODO: we should only collect properties that have getter and setter. However, some tests
      // pass in scala case class as java bean class which doesn't have getter and setter.
      val properties = getJavaBeanReadableProperties(c)
-      // add type variables from inheritance hierarchy of the class
-      val classTV = JavaTypeUtils.getTypeArguments(c, classOf[Object]).asScala.toMap ++
-        typeVariables
-      // Note that the fields are ordered by name.
-      val fields = properties.map { property =>
-        val readMethod = property.getReadMethod
-        val encoder = encoderFor(readMethod.getGenericReturnType, seenTypeSet + c, classTV)
-        // The existence of `javax.annotation.Nonnull`, means this field is not nullable.
-        val hasNonNull = readMethod.isAnnotationPresent(classOf[Nonnull])
-        EncoderField(
-          property.getName,
-          encoder,
-          encoder.nullable && !hasNonNull,
-          Metadata.empty,
-          Option(readMethod.getName),
-          Option(property.getWriteMethod).map(_.getName))
+
+      // if the properties is empty and this is not a top level enclosing class, then we
+      // should not consider class as bean, as otherwise it will be treated as empty schema
+      // and loose the data on deser.
+      if (properties.isEmpty && seenTypeSet.nonEmpty) {
+        findBestEncoder(Seq(c), seenTypeSet, typeVariables, None, serializableEncodersOnly = true)
+          .getOrElse(throw ExecutionErrors.cannotFindEncoderForTypeError(t.getTypeName))
+      } else {
+        // add type variables from inheritance hierarchy of the class
+        val parentClassesTypeMap =
+          JavaTypeUtils.getTypeArguments(c, classOf[Object]).asScala.toMap
+        val classTV = parentClassesTypeMap ++ typeVariables
+        // Note that the fields are ordered by name.
+        val fields = properties.map { property =>
+          val readMethod = property.getReadMethod
+          val methodReturnType = readMethod.getGenericReturnType
+          val encoder = encoderFor(methodReturnType, seenTypeSet + c, classTV)
+          // The existence of `javax.annotation.Nonnull`, means this field is not nullable.
+          val hasNonNull = readMethod.isAnnotationPresent(classOf[Nonnull])
+          EncoderField(
+            property.getName,
+            encoder,
+            encoder.nullable && !hasNonNull,
+            Metadata.empty,
+            Option(readMethod.getName),
+            Option(property.getWriteMethod).map(_.getName))
+        }
+        // implies it cannot be assumed a BeanClass.
+        // Check if its super class or interface could be represented by an Encoder
Review Comment:
These two comments are a bit confusing.
The first comment says "implies it cannot be assumed a BeanClass.", but then
why is a `JavaBeanEncoder` constructed right after? The second comment says
"Check if its super class or interface could be represented by an Encoder",
and it is not clear to me where exactly that check is being done. Could you
please elaborate? Thanks.
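
If the intent is what I suspect, rewording the comments and moving them next
to the branch they describe might help, e.g. (a sketch of my reading only):

```scala
// Reaching this branch means the class does expose bean properties, so it is
// encoded as a bean. The superclass/interface fallback (findBestEncoder) is
// only attempted in the properties.isEmpty branch above.
JavaBeanEncoder(ClassTag(c), fields.toImmutableArraySeq)
```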
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]