Repository: spark Updated Branches: refs/heads/master ab1650d29 -> 434ada12a
[SPARK-17952][SQL] Nested Java beans support in createDataFrame ## What changes were proposed in this pull request? When constructing a DataFrame from a Java bean, using nested beans throws an error despite [documentation](http://spark.apache.org/docs/latest/sql-programming-guide.html#inferring-the-schema-using-reflection) stating otherwise. This PR aims to add that support. This PR does not yet add nested beans support in array or List fields. This can be added later or in another PR. ## How was this patch tested? Nested bean was added to the appropriate unit test. Also manually tested in Spark shell on code emulating the referenced JIRA: ``` scala> import scala.beans.BeanProperty import scala.beans.BeanProperty scala> class SubCategory(@BeanProperty var id: String, @BeanProperty var name: String) extends Serializable defined class SubCategory scala> class Category(@BeanProperty var id: String, @BeanProperty var subCategory: SubCategory) extends Serializable defined class Category scala> import scala.collection.JavaConverters._ import scala.collection.JavaConverters._ scala> spark.createDataFrame(Seq(new Category("s-111", new SubCategory("sc-111", "Sub-1"))).asJava, classOf[Category]) java.lang.IllegalArgumentException: The value (SubCategory@65130cf2) of the type (SubCategory) cannot be converted to struct<id:string,name:string> at org.apache.spark.sql.catalyst.CatalystTypeConverters$StructConverter.toCatalystImpl(CatalystTypeConverters.scala:262) at org.apache.spark.sql.catalyst.CatalystTypeConverters$StructConverter.toCatalystImpl(CatalystTypeConverters.scala:238) at org.apache.spark.sql.catalyst.CatalystTypeConverters$CatalystTypeConverter.toCatalyst(CatalystTypeConverters.scala:103) at org.apache.spark.sql.catalyst.CatalystTypeConverters$$anonfun$createToCatalystConverter$2.apply(CatalystTypeConverters.scala:396) at org.apache.spark.sql.SQLContext$$anonfun$beansToRows$1$$anonfun$apply$1.apply(SQLContext.scala:1108) at 
org.apache.spark.sql.SQLContext$$anonfun$beansToRows$1$$anonfun$apply$1.apply(SQLContext.scala:1108) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234) at scala.collection.IndexedSeqOptimized$class.foreach(IndexedSeqOptimized.scala:33) at scala.collection.mutable.ArrayOps$ofRef.foreach(ArrayOps.scala:186) at scala.collection.TraversableLike$class.map(TraversableLike.scala:234) at scala.collection.mutable.ArrayOps$ofRef.map(ArrayOps.scala:186) at org.apache.spark.sql.SQLContext$$anonfun$beansToRows$1.apply(SQLContext.scala:1108) at org.apache.spark.sql.SQLContext$$anonfun$beansToRows$1.apply(SQLContext.scala:1106) at scala.collection.Iterator$$anon$11.next(Iterator.scala:410) at scala.collection.Iterator$class.toStream(Iterator.scala:1320) at scala.collection.AbstractIterator.toStream(Iterator.scala:1334) at scala.collection.TraversableOnce$class.toSeq(TraversableOnce.scala:298) at scala.collection.AbstractIterator.toSeq(Iterator.scala:1334) at org.apache.spark.sql.SparkSession.createDataFrame(SparkSession.scala:423) ... 51 elided ``` New behavior: ``` scala> spark.createDataFrame(Seq(new Category("s-111", new SubCategory("sc-111", "Sub-1"))).asJava, classOf[Category]) res0: org.apache.spark.sql.DataFrame = [id: string, subCategory: struct<id: string, name: string>] scala> res0.show() +-----+---------------+ | id| subCategory| +-----+---------------+ |s-111|[sc-111, Sub-1]| +-----+---------------+ ``` Closes #22527 from michalsenkyr/SPARK-17952. 
Authored-by: Michal Senkyr <mike.sen...@gmail.com> Signed-off-by: Takuya UESHIN <ues...@databricks.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/434ada12 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/434ada12 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/434ada12 Branch: refs/heads/master Commit: 434ada12a06d1d2d3cb19c4eac5a52f330bb236c Parents: ab1650d Author: Michal Senkyr <mike.sen...@gmail.com> Authored: Fri Oct 5 17:48:52 2018 +0900 Committer: Takuya UESHIN <ues...@databricks.com> Committed: Fri Oct 5 17:48:52 2018 +0900 ---------------------------------------------------------------------- .../scala/org/apache/spark/sql/SQLContext.scala | 29 +++++++++++++------ .../apache/spark/sql/JavaDataFrameSuite.java | 30 +++++++++++++++++++- 2 files changed, 50 insertions(+), 9 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/434ada12/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index af60184..dfb12f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -1098,16 +1098,29 @@ object SQLContext { data: Iterator[_], beanClass: Class[_], attrs: Seq[AttributeReference]): Iterator[InternalRow] = { - val extractors = - JavaTypeInference.getJavaBeanReadableProperties(beanClass).map(_.getReadMethod) - val methodsToConverts = extractors.zip(attrs).map { case (e, attr) => - (e, CatalystTypeConverters.createToCatalystConverter(attr.dataType)) + def createStructConverter(cls: Class[_], fieldTypes: Seq[DataType]): Any => InternalRow = { + val methodConverters = + 
JavaTypeInference.getJavaBeanReadableProperties(cls).zip(fieldTypes) + .map { case (property, fieldType) => + val method = property.getReadMethod + method -> createConverter(method.getReturnType, fieldType) + } + value => + if (value == null) { + null + } else { + new GenericInternalRow( + methodConverters.map { case (method, converter) => + converter(method.invoke(value)) + }) + } } - data.map { element => - new GenericInternalRow( - methodsToConverts.map { case (e, convert) => convert(e.invoke(element)) } - ): InternalRow + def createConverter(cls: Class[_], dataType: DataType): Any => Any = dataType match { + case struct: StructType => createStructConverter(cls, struct.map(_.dataType)) + case _ => CatalystTypeConverters.createToCatalystConverter(dataType) } + val dataConverter = createStructConverter(beanClass, attrs.map(_.dataType)) + data.map(dataConverter) } /** http://git-wip-us.apache.org/repos/asf/spark/blob/434ada12/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java ---------------------------------------------------------------------- diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 00f41d6..a05afa4 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -134,6 +134,8 @@ public class JavaDataFrameSuite { private Map<String, int[]> c = ImmutableMap.of("hello", new int[] { 1, 2 }); private List<String> d = Arrays.asList("floppy", "disk"); private BigInteger e = new BigInteger("1234567"); + private NestedBean f = new NestedBean(); + private NestedBean g = null; public double getA() { return a; @@ -152,6 +154,22 @@ public class JavaDataFrameSuite { } public BigInteger getE() { return e; } + + public NestedBean getF() { + return f; + } + + public NestedBean getG() { + return g; + } + + public static class 
NestedBean implements Serializable { + private int a = 1; + + public int getA() { + return a; + } + } } void validateDataFrameWithBeans(Bean bean, Dataset<Row> df) { @@ -171,7 +189,14 @@ public class JavaDataFrameSuite { schema.apply("d")); Assert.assertEquals(new StructField("e", DataTypes.createDecimalType(38,0), true, Metadata.empty()), schema.apply("e")); - Row first = df.select("a", "b", "c", "d", "e").first(); + StructType nestedBeanType = + DataTypes.createStructType(Collections.singletonList(new StructField( + "a", IntegerType$.MODULE$, false, Metadata.empty()))); + Assert.assertEquals(new StructField("f", nestedBeanType, true, Metadata.empty()), + schema.apply("f")); + Assert.assertEquals(new StructField("g", nestedBeanType, true, Metadata.empty()), + schema.apply("g")); + Row first = df.select("a", "b", "c", "d", "e", "f", "g").first(); Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0); // Now Java lists and maps are converted to Scala Seq's and Map's. Once we get a Seq below, // verify that it has the expected length, and contains expected elements. @@ -192,6 +217,9 @@ public class JavaDataFrameSuite { } // Java.math.BigInteger is equivalent to Spark Decimal(38,0) Assert.assertEquals(new BigDecimal(bean.getE()), first.getDecimal(4)); + Row nested = first.getStruct(5); + Assert.assertEquals(bean.getF().getA(), nested.getInt(0)); + Assert.assertTrue(first.isNullAt(6)); } @Test --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org