Repository: spark
Updated Branches:
  refs/heads/master bce177552 -> 8f225e055
[SPARK-24548][SQL] Fix incorrect schema of Dataset with tuple encoders

## What changes were proposed in this pull request?

When creating tuple expression encoders, we should give the serializer expressions of tuple items correct names, so we can have correct output schema when we use such tuple encoders.

## How was this patch tested?

Added test.

Author: Liang-Chi Hsieh <vii...@gmail.com>

Closes #21576 from viirya/SPARK-24548.

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8f225e05
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8f225e05
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8f225e05

Branch: refs/heads/master
Commit: 8f225e055c2031ca85d61721ab712170ab4e50c1
Parents: bce1775
Author: Liang-Chi Hsieh <vii...@gmail.com>
Authored: Mon Jun 18 11:01:17 2018 -0700
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Mon Jun 18 11:01:17 2018 -0700

----------------------------------------------------------------------
 .../sql/catalyst/encoders/ExpressionEncoder.scala |  3 ++-
 .../org/apache/spark/sql/JavaDatasetSuite.java    | 18 ++++++++++++++++++
 .../scala/org/apache/spark/sql/DatasetSuite.scala | 13 +++++++++++++
 3 files changed, 33 insertions(+), 1 deletion(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/spark/blob/8f225e05/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index efc2882..cbea3c0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -128,7 +128,7 @@ object ExpressionEncoder {
         case b: BoundReference if b == originalInputObject => newInputObject
       })
 
-      if (enc.flat) {
+      val serializerExpr = if (enc.flat) {
         newSerializer.head
       } else {
         // For non-flat encoder, the input object is not top level anymore after being combined to
@@ -146,6 +146,7 @@ object ExpressionEncoder {
           Invoke(Literal.fromObject(None), "equals", BooleanType, newInputObject :: Nil))
         If(nullCheck, Literal.create(null, struct.dataType), struct)
       }
+      Alias(serializerExpr, s"_${index + 1}")()
     }
 
     val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/8f225e05/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index c132cab..2c695fc 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -34,6 +34,7 @@ import com.google.common.base.Objects;
 import org.junit.*;
 import org.junit.rules.ExpectedException;
 
+import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.*;
 import org.apache.spark.sql.*;
@@ -337,6 +338,23 @@ public class JavaDatasetSuite implements Serializable {
   }
 
   @Test
+  public void testTupleEncoderSchema() {
+    Encoder<Tuple2<String, Tuple2<String,String>>> encoder =
+      Encoders.tuple(Encoders.STRING(), Encoders.tuple(Encoders.STRING(), Encoders.STRING()));
+    List<Tuple2<String, Tuple2<String, String>>> data = Arrays.asList(tuple2("1", tuple2("a", "b")),
+      tuple2("2", tuple2("c", "d")));
+    Dataset<Row> ds1 = spark.createDataset(data, encoder).toDF("value1", "value2");
+
+    JavaPairRDD<String, Tuple2<String, String>> pairRDD = jsc.parallelizePairs(data);
+    Dataset<Row> ds2 = spark.createDataset(JavaPairRDD.toRDD(pairRDD), encoder)
+      .toDF("value1", "value2");
+
+    Assert.assertEquals(ds1.schema(), ds2.schema());
+    Assert.assertEquals(ds1.select(expr("value2._1")).collectAsList(),
+      ds2.select(expr("value2._1")).collectAsList());
+  }
+
+  @Test
   public void testNestedTupleEncoder() {
     // test ((int, string), string)
     Encoder<Tuple2<Tuple2<Integer, String>, String>> encoder =

http://git-wip-us.apache.org/repos/asf/spark/blob/8f225e05/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index d477d78..093cee9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1466,6 +1466,19 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     val ds = Seq[(Option[Int], Option[Int])]((Some(1), None)).toDS()
     intercept[NullPointerException](ds.as[(Int, Int)].collect())
   }
+
+  test("SPARK-24548: Dataset with tuple encoders should have correct schema") {
+    val encoder = Encoders.tuple(newStringEncoder,
+      Encoders.tuple(newStringEncoder, newStringEncoder))
+
+    val data = Seq(("a", ("1", "2")), ("b", ("3", "4")))
+    val rdd = sparkContext.parallelize(data)
+
+    val ds1 = spark.createDataset(rdd)
+    val ds2 = spark.createDataset(rdd)(encoder)
+    assert(ds1.schema == ds2.schema)
+    checkDataset(ds1.select("_2._2"), ds2.select("_2._2").collect(): _*)
+  }
 }
 
 case class TestDataUnion(x: Int, y: Int, z: Int)

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org