viirya commented on code in PR #39615: URL: https://github.com/apache/spark/pull/39615#discussion_r1590399660
########## sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala: ########## @@ -166,317 +148,58 @@ object JavaTypeInference { .filter(_.getReadMethod != null) } - private def getJavaBeanReadableAndWritableProperties( - beanClass: Class[_]): Array[PropertyDescriptor] = { - getJavaBeanReadableProperties(beanClass).filter(_.getWriteMethod != null) - } - - private def elementType(typeToken: TypeToken[_]): TypeToken[_] = { - val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JIterable[_]]] - val iterableSuperType = typeToken2.getSupertype(classOf[JIterable[_]]) - val iteratorType = iterableSuperType.resolveType(iteratorReturnType) - iteratorType.resolveType(nextReturnType) - } - - private def mapKeyValueType(typeToken: TypeToken[_]): (TypeToken[_], TypeToken[_]) = { - val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]] - val mapSuperType = typeToken2.getSupertype(classOf[JMap[_, _]]) - val keyType = elementType(mapSuperType.resolveType(keySetReturnType)) - val valueType = elementType(mapSuperType.resolveType(valuesReturnType)) - keyType -> valueType - } - - /** - * Returns the Spark SQL DataType for a given java class. Where this is not an exact mapping - * to a native type, an ObjectType is returned. - * - * Unlike `inferDataType`, this function doesn't do any massaging of types into the Spark SQL type - * system. As a result, ObjectType will be returned for things like boxed Integers. 
- */ - private def inferExternalType(cls: Class[_]): DataType = cls match { - case c if c == java.lang.Boolean.TYPE => BooleanType - case c if c == java.lang.Byte.TYPE => ByteType - case c if c == java.lang.Short.TYPE => ShortType - case c if c == java.lang.Integer.TYPE => IntegerType - case c if c == java.lang.Long.TYPE => LongType - case c if c == java.lang.Float.TYPE => FloatType - case c if c == java.lang.Double.TYPE => DoubleType - case c if c == classOf[Array[Byte]] => BinaryType - case _ => ObjectType(cls) - } - - /** - * Returns an expression that can be used to deserialize a Spark SQL representation to an object - * of java bean `T` with a compatible schema. The Spark SQL representation is located at ordinal - * 0 of a row, i.e., `GetColumnByOrdinal(0, _)`. Nested classes will have their fields accessed - * using `UnresolvedExtractValue`. - */ - def deserializerFor(beanClass: Class[_]): Expression = { - val typeToken = TypeToken.of(beanClass) - val walkedTypePath = new WalkedTypePath().recordRoot(beanClass.getCanonicalName) - val (dataType, nullable) = inferDataType(typeToken) - - // Assumes we are deserializing the first column of a row. 
- deserializerForWithNullSafetyAndUpcast(GetColumnByOrdinal(0, dataType), dataType, - nullable = nullable, walkedTypePath, deserializerFor(typeToken, _, walkedTypePath)) - } - - private def deserializerFor( - typeToken: TypeToken[_], - path: Expression, - walkedTypePath: WalkedTypePath): Expression = { - typeToken.getRawType match { - case c if !inferExternalType(c).isInstanceOf[ObjectType] => path - - case c if c == classOf[java.lang.Short] || - c == classOf[java.lang.Integer] || - c == classOf[java.lang.Long] || - c == classOf[java.lang.Double] || - c == classOf[java.lang.Float] || - c == classOf[java.lang.Byte] || - c == classOf[java.lang.Boolean] => - createDeserializerForTypesSupportValueOf(path, c) - - case c if c == classOf[java.time.LocalDate] => - createDeserializerForLocalDate(path) - - case c if c == classOf[java.sql.Date] => - createDeserializerForSqlDate(path) - - case c if c == classOf[java.time.Instant] => - createDeserializerForInstant(path) - - case c if c == classOf[java.sql.Timestamp] => - createDeserializerForSqlTimestamp(path) + private class ImplementsGenericInterface(interface: Class[_]) { + assert(interface.isInterface) + assert(interface.getTypeParameters.nonEmpty) - case c if c == classOf[java.time.LocalDateTime] => - createDeserializerForLocalDateTime(path) - - case c if c == classOf[java.time.Duration] => - createDeserializerForDuration(path) - - case c if c == classOf[java.time.Period] => - createDeserializerForPeriod(path) - - case c if c == classOf[java.lang.String] => - createDeserializerForString(path, returnNullable = true) - - case c if c == classOf[java.math.BigDecimal] => - createDeserializerForJavaBigDecimal(path, returnNullable = true) - - case c if c == classOf[java.math.BigInteger] => - createDeserializerForJavaBigInteger(path, returnNullable = true) - - case c if c.isArray => - val elementType = c.getComponentType - val newTypePath = walkedTypePath.recordArray(elementType.getCanonicalName) - val (dataType, elementNullable) 
= inferDataType(elementType) - val mapFunction: Expression => Expression = element => { - // upcast the array element to the data type the encoder expected. - deserializerForWithNullSafetyAndUpcast( - element, - dataType, - nullable = elementNullable, - newTypePath, - deserializerFor(typeToken.getComponentType, _, newTypePath)) - } - - val arrayData = UnresolvedMapObjects(mapFunction, path) - - val methodName = elementType match { - case c if c == java.lang.Integer.TYPE => "toIntArray" - case c if c == java.lang.Long.TYPE => "toLongArray" - case c if c == java.lang.Double.TYPE => "toDoubleArray" - case c if c == java.lang.Float.TYPE => "toFloatArray" - case c if c == java.lang.Short.TYPE => "toShortArray" - case c if c == java.lang.Byte.TYPE => "toByteArray" - case c if c == java.lang.Boolean.TYPE => "toBooleanArray" - // non-primitive - case _ => "array" - } - Invoke(arrayData, methodName, ObjectType(c)) - - case c if ttIsAssignableFrom(listType, typeToken) => - val et = elementType(typeToken) - val newTypePath = walkedTypePath.recordArray(et.getType.getTypeName) - val (dataType, elementNullable) = inferDataType(et) - val mapFunction: Expression => Expression = element => { - // upcast the array element to the data type the encoder expected. 
- deserializerForWithNullSafetyAndUpcast( - element, - dataType, - nullable = elementNullable, - newTypePath, - deserializerFor(et, _, newTypePath)) - } - - UnresolvedMapObjects(mapFunction, path, customCollectionCls = Some(c)) - - case _ if ttIsAssignableFrom(mapType, typeToken) => - val (keyType, valueType) = mapKeyValueType(typeToken) - val newTypePath = walkedTypePath.recordMap(keyType.getType.getTypeName, - valueType.getType.getTypeName) - - val keyData = - Invoke( - UnresolvedMapObjects( - p => deserializerFor(keyType, p, newTypePath), - MapKeys(path)), - "array", - ObjectType(classOf[Array[Any]])) - - val valueData = - Invoke( - UnresolvedMapObjects( - p => deserializerFor(valueType, p, newTypePath), - MapValues(path)), - "array", - ObjectType(classOf[Array[Any]])) - - StaticInvoke( - ArrayBasedMapData.getClass, - ObjectType(classOf[JMap[_, _]]), - "toJavaMap", - keyData :: valueData :: Nil, - returnNullable = false) - - case other if other.isEnum => - createDeserializerForTypesSupportValueOf( - createDeserializerForString(path, returnNullable = false), - other) - - case other => - val properties = getJavaBeanReadableAndWritableProperties(other) - val setters = properties.map { p => - val fieldName = p.getName - val fieldType = typeToken.method(p.getReadMethod).getReturnType - val (dataType, nullable) = inferDataType(fieldType) - val newTypePath = walkedTypePath.recordField(fieldType.getType.getTypeName, fieldName) - // The existence of `javax.annotation.Nonnull`, means this field is not nullable. 
- val hasNonNull = p.getReadMethod.isAnnotationPresent(classOf[Nonnull]) - val setter = expressionWithNullSafety( - deserializerFor(fieldType, addToPath(path, fieldName, dataType, newTypePath), - newTypePath), - nullable = nullable && !hasNonNull, - newTypePath) - p.getWriteMethod.getName -> setter - }.toMap - - val newInstance = NewInstance(other, Nil, ObjectType(other), propagateNull = false) - val result = InitializeJavaBean(newInstance, setters) - - expressions.If( - IsNull(path), - expressions.Literal.create(null, ObjectType(other)), - result - ) + def unapply(t: Type): Option[(Class[_], Array[Type])] = implementsInterface(t).map { cls => + cls -> findTypeArgumentsForInterface(t) } - } - /** - * Returns an expression for serializing an object of the given type to a Spark SQL - * representation. The input object is located at ordinal 0 of a row, i.e., - * `BoundReference(0, _)`. - */ - def serializerFor(beanClass: Class[_]): Expression = { - val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true) - val nullSafeInput = AssertNotNull(inputObject, Seq("top level input bean")) - serializerFor(nullSafeInput, TypeToken.of(beanClass)) - } - - private def serializerFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = { - - def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = { - val (dataType, nullable) = inferDataType(elementType) - if (ScalaReflection.isNativeType(dataType)) { - val cls = input.dataType.asInstanceOf[ObjectType].cls - if (cls.isArray && cls.getComponentType.isPrimitive) { - createSerializerForPrimitiveArray(input, dataType) - } else { - createSerializerForGenericArray(input, dataType, nullable = nullable) - } - } else { - createSerializerForMapObjects(input, ObjectType(elementType.getRawType), - serializerFor(_, elementType)) - } + @tailrec + private def implementsInterface(t: Type): Option[Class[_]] = t match { + case pt: ParameterizedType => implementsInterface(pt.getRawType) + case c: 
Class[_] if interface.isAssignableFrom(c) => Option(c) + case _ => None } - if (!inputObject.dataType.isInstanceOf[ObjectType]) { - inputObject - } else { - typeToken.getRawType match { - case c if c == classOf[String] => createSerializerForString(inputObject) - - case c if c == classOf[java.time.Instant] => createSerializerForJavaInstant(inputObject) - - case c if c == classOf[java.sql.Timestamp] => createSerializerForSqlTimestamp(inputObject) - - case c if c == classOf[java.time.LocalDateTime] => - createSerializerForLocalDateTime(inputObject) - - case c if c == classOf[java.time.LocalDate] => createSerializerForJavaLocalDate(inputObject) - - case c if c == classOf[java.sql.Date] => createSerializerForSqlDate(inputObject) - - case c if c == classOf[java.time.Duration] => createSerializerForJavaDuration(inputObject) - - case c if c == classOf[java.time.Period] => createSerializerForJavaPeriod(inputObject) - - case c if c == classOf[java.math.BigInteger] => - createSerializerForBigInteger(inputObject) - - case c if c == classOf[java.math.BigDecimal] => - createSerializerForBigDecimal(inputObject) - - case c if c == classOf[java.lang.Boolean] => createSerializerForBoolean(inputObject) - case c if c == classOf[java.lang.Byte] => createSerializerForByte(inputObject) - case c if c == classOf[java.lang.Short] => createSerializerForShort(inputObject) - case c if c == classOf[java.lang.Integer] => createSerializerForInteger(inputObject) - case c if c == classOf[java.lang.Long] => createSerializerForLong(inputObject) - case c if c == classOf[java.lang.Float] => createSerializerForFloat(inputObject) - case c if c == classOf[java.lang.Double] => createSerializerForDouble(inputObject) - - case _ if typeToken.isArray => - toCatalystArray(inputObject, typeToken.getComponentType) - - case _ if ttIsAssignableFrom(listType, typeToken) => - toCatalystArray(inputObject, elementType(typeToken)) - - case _ if ttIsAssignableFrom(mapType, typeToken) => - val (keyType, valueType) = 
mapKeyValueType(typeToken) - - createSerializerForMap( - inputObject, - MapElementInformation( - ObjectType(keyType.getRawType), - nullable = true, - serializerFor(_, keyType)), - MapElementInformation( - ObjectType(valueType.getRawType), - nullable = true, - serializerFor(_, valueType)) - ) - - case other if other.isEnum => - createSerializerForString( - Invoke(inputObject, "name", ObjectType(classOf[String]), returnNullable = false)) - - case other => - val properties = getJavaBeanReadableAndWritableProperties(other) Review Comment: Hmm, I'm not sure whether this has been discussed, but this is a breaking change that causes customer issues such as https://issues.apache.org/jira/browse/SPARK-48073 when upgrading from Spark 3.2 to Spark 3.4. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org