Repository: spark Updated Branches: refs/heads/branch-1.6 faf094c7c -> a6190508b
http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java ---------------------------------------------------------------------- diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 0dbaeb8..9f8db39 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -23,6 +23,8 @@ import java.sql.Date; import java.sql.Timestamp; import java.util.*; +import com.google.common.base.Objects; +import org.junit.rules.ExpectedException; import scala.Tuple2; import scala.Tuple3; import scala.Tuple4; @@ -39,7 +41,6 @@ import org.apache.spark.sql.expressions.Aggregator; import org.apache.spark.sql.test.TestSQLContext; import org.apache.spark.sql.catalyst.encoders.OuterScopes; import org.apache.spark.sql.catalyst.expressions.GenericRow; -import org.apache.spark.sql.types.DecimalType; import org.apache.spark.sql.types.StructType; import static org.apache.spark.sql.functions.*; @@ -741,4 +742,127 @@ public class JavaDatasetSuite implements Serializable { context.createDataset(Arrays.asList(obj), Encoders.bean(SimpleJavaBean2.class)); ds.collect(); } + + public class SmallBean implements Serializable { + private String a; + + private int b; + + public int getB() { + return b; + } + + public void setB(int b) { + this.b = b; + } + + public String getA() { + return a; + } + + public void setA(String a) { + this.a = a; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SmallBean smallBean = (SmallBean) o; + return b == smallBean.b && com.google.common.base.Objects.equal(a, smallBean.a); + } + + @Override + public int hashCode() { + return Objects.hashCode(a, b); + } + } + + public class NestedSmallBean implements Serializable { + private SmallBean f; + + public SmallBean getF() { + return f; + } + + public void setF(SmallBean f) { + this.f = f; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + NestedSmallBean that = (NestedSmallBean) o; + return Objects.equal(f, that.f); + } + + @Override + public int hashCode() { + return Objects.hashCode(f); + } + } + + @Rule + public transient ExpectedException nullabilityCheck = ExpectedException.none(); + + @Test + public void testRuntimeNullabilityCheck() { + OuterScopes.addOuterScope(this); + + StructType schema = new StructType() + .add("f", new StructType() + .add("a", StringType, true) + .add("b", IntegerType, true), true); + + // Shouldn't throw runtime exception since it passes nullability check. + { + Row row = new GenericRow(new Object[] { + new GenericRow(new Object[] { + "hello", 1 + }) + }); + + DataFrame df = context.createDataFrame(Collections.singletonList(row), schema); + Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class)); + + SmallBean smallBean = new SmallBean(); + smallBean.setA("hello"); + smallBean.setB(1); + + NestedSmallBean nestedSmallBean = new NestedSmallBean(); + nestedSmallBean.setF(smallBean); + + Assert.assertEquals(ds.collectAsList(), Collections.singletonList(nestedSmallBean)); + } + + // Shouldn't throw runtime exception when parent object (`ClassData`) is null + { + Row row = new GenericRow(new Object[] { null }); + + DataFrame df = context.createDataFrame(Collections.singletonList(row), schema); + Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class)); + + NestedSmallBean nestedSmallBean = new NestedSmallBean(); + Assert.assertEquals(ds.collectAsList(), Collections.singletonList(nestedSmallBean)); + } + + nullabilityCheck.expect(RuntimeException.class); + nullabilityCheck.expectMessage( + "Null value appeared in non-nullable field " + + "test.org.apache.spark.sql.JavaDatasetSuite$SmallBean.b of type int."); + + { + Row row = new GenericRow(new Object[] { + new GenericRow(new Object[] { + "hello", null + }) + }); + + DataFrame df = context.createDataFrame(Collections.singletonList(row), schema); + Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class)); + + ds.collect(); + } + } } http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 854dec0..0b7573c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -578,6 +578,21 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.showString(10) === expectedAnswer) } + test("showString: binary") { + val df = Seq( + ("12".getBytes, "ABC.".getBytes), + ("34".getBytes, "12346".getBytes) + ).toDF() + val expectedAnswer = """+-------+----------------+ + || _1| _2| + |+-------+----------------+ + ||[31 32]| [41 42 43 2E]| + ||[33 34]|[31 32 33 34 36]| + |+-------+----------------+ + |""".stripMargin + assert(df.showString(10) === expectedAnswer) + } + test("showString: minimum column width") { val df = Seq( (1, 1), http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index c6b3991..c19b5a4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -24,6 +24,7 @@ import scala.language.postfixOps import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} class DatasetSuite extends QueryTest with SharedSQLContext { @@ -438,6 +439,20 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(ds.toString == "[_1: int, _2: int]") } + test("showString: Kryo encoder") { + implicit val kryoEncoder = Encoders.kryo[KryoData] + val ds = Seq(KryoData(1), KryoData(2)).toDS() + + val expectedAnswer = """+-----------+ + || value| + |+-----------+ + ||KryoData(1)| + ||KryoData(2)| + |+-----------+ + |""".stripMargin + assert(ds.showString(10) === expectedAnswer) + } + test("Kryo encoder") { implicit val kryoEncoder = Encoders.kryo[KryoData] val ds = Seq(KryoData(1), KryoData(2)).toDS() @@ -493,12 +508,63 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val ds = Seq(2 -> 2.toByte, 3 -> 3.toByte).toDF("a", "b").as[ClassData] assert(ds.collect().toSeq == Seq(ClassData("2", 2), ClassData("3", 3))) } -} + test("verify mismatching field names fail with a good error") { + val ds = Seq(ClassData("a", 1)).toDS() + val e = intercept[AnalysisException] { + ds.as[ClassData2].collect() + } + assert(e.getMessage.contains("cannot resolve 'c' given input columns: [a, b]"), e.getMessage) + } + + test("runtime nullability check") { + val schema = StructType(Seq( + StructField("f", StructType(Seq( + StructField("a", StringType, nullable = true), + StructField("b", IntegerType, nullable = false) + )), nullable = true) + )) + + def buildDataset(rows: Row*): Dataset[NestedStruct] = { + val rowRDD = sqlContext.sparkContext.parallelize(rows) + sqlContext.createDataFrame(rowRDD, schema).as[NestedStruct] + } + + checkAnswer( + buildDataset(Row(Row("hello", 1))), + NestedStruct(ClassData("hello", 1)) + ) + + // Shouldn't throw runtime exception when parent object (`ClassData`) is null + assert(buildDataset(Row(null)).collect() === Array(NestedStruct(null))) + + val message = intercept[RuntimeException] { + buildDataset(Row(Row("hello", null))).collect() + }.getMessage + + assert(message.contains( + "Null value appeared in non-nullable field org.apache.spark.sql.ClassData.b of type Int." + )) + } + + test("SPARK-12478: top level null field") { + val ds0 = Seq(NestedStruct(null)).toDS() + checkAnswer(ds0, NestedStruct(null)) + checkAnswer(ds0.toDF(), Row(null)) + + val ds1 = Seq(DeepNestedStruct(NestedStruct(null))).toDS() + checkAnswer(ds1, DeepNestedStruct(NestedStruct(null))) + checkAnswer(ds1.toDF(), Row(Row(null))) + } +} case class ClassData(a: String, b: Int) +case class ClassData2(c: String, d: Int) case class ClassNullableData(a: String, b: Integer) +case class NestedStruct(f: ClassData) +case class DeepNestedStruct(f: NestedStruct) + /** * A class used to test serialization using encoders. This class throws exceptions when using * Java serialization -- so the only way it can be "serialized" is through our encoders. http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index bc22fb8..9246f55 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -21,10 +21,15 @@ import java.util.{Locale, TimeZone} import scala.collection.JavaConverters._ -import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.columnar.InMemoryRelation -import org.apache.spark.sql.execution.Queryable +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.{LogicalRDD, Queryable} abstract class QueryTest extends PlanTest { @@ -123,6 +128,8 @@ abstract class QueryTest extends PlanTest { |""".stripMargin) } + checkJsonFormat(analyzedDF) + QueryTest.checkAnswer(analyzedDF, expectedAnswer) match { case Some(errorMessage) => fail(errorMessage) case None => @@ -177,6 +184,97 @@ abstract class QueryTest extends PlanTest { s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + planWithCaching) } + + private def checkJsonFormat(df: DataFrame): Unit = { + val logicalPlan = df.queryExecution.analyzed + // bypass some cases that we can't handle currently. + logicalPlan.transform { + case _: MapPartitions[_, _] => return + case _: MapGroups[_, _, _] => return + case _: AppendColumns[_, _] => return + case _: CoGroup[_, _, _, _] => return + case _: LogicalRelation => return + }.transformAllExpressions { + case a: ImperativeAggregate => return + } + + val jsonString = try { + logicalPlan.toJSON + } catch { + case e => + fail( + s""" + |Failed to parse logical plan to JSON: + |${logicalPlan.treeString} + """.stripMargin, e) + } + + // bypass hive tests before we fix all corner cases in hive module. + if (this.getClass.getName.startsWith("org.apache.spark.sql.hive")) return + + // scala function is not serializable to JSON, use null to replace them so that we can compare + // the plans later. + val normalized1 = logicalPlan.transformAllExpressions { + case udf: ScalaUDF => udf.copy(function = null) + case gen: UserDefinedGenerator => gen.copy(function = null) + } + + // RDDs/data are not serializable to JSON, so we need to collect LogicalPlans that contains + // these non-serializable stuff, and use these original ones to replace the null-placeholders + // in the logical plans parsed from JSON. + var logicalRDDs = logicalPlan.collect { case l: LogicalRDD => l } + var localRelations = logicalPlan.collect { case l: LocalRelation => l } + var inMemoryRelations = logicalPlan.collect { case i: InMemoryRelation => i } + + val jsonBackPlan = try { + TreeNode.fromJSON[LogicalPlan](jsonString, sqlContext.sparkContext) + } catch { + case e => + fail( + s""" + |Failed to rebuild the logical plan from JSON: + |${logicalPlan.treeString} + | + |${logicalPlan.prettyJson} + """.stripMargin, e) + } + + val normalized2 = jsonBackPlan transformDown { + case l: LogicalRDD => + val origin = logicalRDDs.head + logicalRDDs = logicalRDDs.drop(1) + LogicalRDD(l.output, origin.rdd)(sqlContext) + case l: LocalRelation => + val origin = localRelations.head + localRelations = localRelations.drop(1) + l.copy(data = origin.data) + case l: InMemoryRelation => + val origin = inMemoryRelations.head + inMemoryRelations = inMemoryRelations.drop(1) + InMemoryRelation( + l.output, + l.useCompression, + l.batchSize, + l.storageLevel, + origin.child, + l.tableName)( + origin.cachedColumnBuffers, + l._statistics, + origin._batchStats) + } + + assert(logicalRDDs.isEmpty) + assert(localRelations.isEmpty) + assert(inMemoryRelations.isEmpty) + + if (normalized1 != normalized2) { + fail( + s""" + |== FAIL: the logical plan parsed from json does not match the original one === + |${sideBySide(logicalPlan.treeString, normalized2.treeString).mkString("\n")} + """.stripMargin) + } + } } object QueryTest { http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index f602f2f..2a11173 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -65,6 +65,11 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { override def userClass: Class[MyDenseVector] = classOf[MyDenseVector] private[spark] override def asNullable: MyDenseVectorUDT = this + + override def equals(other: Any): Boolean = other match { + case _: MyDenseVectorUDT => true + case _ => false + } } class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest { http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 08b291e..f099e14 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -728,6 +728,8 @@ private[hive] case class MetastoreRelation Objects.hashCode(databaseName, tableName, alias, output) } + override protected def otherCopyArgs: Seq[AnyRef] = table :: sqlContext :: Nil + @transient val hiveQlTable: Table = { // We start by constructing an API table as Hive performs several important transformations // internally when converting an API table to a QL table. http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index b30117f..d9b9ba4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -58,7 +58,7 @@ case class ScriptTransformation( ioschema: HiveScriptIOSchema)(@transient private val sc: HiveContext) extends UnaryNode { - override def otherCopyArgs: Seq[HiveContext] = sc :: Nil + override protected def otherCopyArgs: Seq[HiveContext] = sc :: Nil private val serializedHiveConf = new SerializableConfiguration(sc.hiveconf) --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
