Repository: spark
Updated Branches:
  refs/heads/branch-1.6 faf094c7c -> a6190508b


http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 0dbaeb8..9f8db39 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -23,6 +23,8 @@ import java.sql.Date;
 import java.sql.Timestamp;
 import java.util.*;
 
+import com.google.common.base.Objects;
+import org.junit.rules.ExpectedException;
 import scala.Tuple2;
 import scala.Tuple3;
 import scala.Tuple4;
@@ -39,7 +41,6 @@ import org.apache.spark.sql.expressions.Aggregator;
 import org.apache.spark.sql.test.TestSQLContext;
 import org.apache.spark.sql.catalyst.encoders.OuterScopes;
 import org.apache.spark.sql.catalyst.expressions.GenericRow;
-import org.apache.spark.sql.types.DecimalType;
 import org.apache.spark.sql.types.StructType;
 
 import static org.apache.spark.sql.functions.*;
@@ -741,4 +742,127 @@ public class JavaDatasetSuite implements Serializable {
      context.createDataset(Arrays.asList(obj), Encoders.bean(SimpleJavaBean2.class));
     ds.collect();
   }
+
+  public class SmallBean implements Serializable {
+    private String a;
+
+    private int b;
+
+    public int getB() {
+      return b;
+    }
+
+    public void setB(int b) {
+      this.b = b;
+    }
+
+    public String getA() {
+      return a;
+    }
+
+    public void setA(String a) {
+      this.a = a;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) return true;
+      if (o == null || getClass() != o.getClass()) return false;
+      SmallBean smallBean = (SmallBean) o;
+      return b == smallBean.b && com.google.common.base.Objects.equal(a, smallBean.a);
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hashCode(a, b);
+    }
+  }
+
+  public class NestedSmallBean implements Serializable {
+    private SmallBean f;
+
+    public SmallBean getF() {
+      return f;
+    }
+
+    public void setF(SmallBean f) {
+      this.f = f;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) return true;
+      if (o == null || getClass() != o.getClass()) return false;
+      NestedSmallBean that = (NestedSmallBean) o;
+      return Objects.equal(f, that.f);
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hashCode(f);
+    }
+  }
+
+  @Rule
+  public transient ExpectedException nullabilityCheck = ExpectedException.none();
+
+  @Test
+  public void testRuntimeNullabilityCheck() {
+    OuterScopes.addOuterScope(this);
+
+    StructType schema = new StructType()
+      .add("f", new StructType()
+        .add("a", StringType, true)
+        .add("b", IntegerType, true), true);
+
+    // Shouldn't throw runtime exception since it passes nullability check.
+    {
+      Row row = new GenericRow(new Object[] {
+          new GenericRow(new Object[] {
+              "hello", 1
+          })
+      });
+
+      DataFrame df = context.createDataFrame(Collections.singletonList(row), schema);
+      Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));
+
+      SmallBean smallBean = new SmallBean();
+      smallBean.setA("hello");
+      smallBean.setB(1);
+
+      NestedSmallBean nestedSmallBean = new NestedSmallBean();
+      nestedSmallBean.setF(smallBean);
+
+      Assert.assertEquals(ds.collectAsList(), Collections.singletonList(nestedSmallBean));
+    }
+
+    // Shouldn't throw runtime exception when parent object (`SmallBean`) is null
+    {
+      Row row = new GenericRow(new Object[] { null });
+
+      DataFrame df = context.createDataFrame(Collections.singletonList(row), schema);
+      Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));
+
+      NestedSmallBean nestedSmallBean = new NestedSmallBean();
+      Assert.assertEquals(ds.collectAsList(), Collections.singletonList(nestedSmallBean));
+    }
+
+    nullabilityCheck.expect(RuntimeException.class);
+    nullabilityCheck.expectMessage(
+      "Null value appeared in non-nullable field " +
+        "test.org.apache.spark.sql.JavaDatasetSuite$SmallBean.b of type int.");
+
+    {
+      Row row = new GenericRow(new Object[] {
+          new GenericRow(new Object[] {
+              "hello", null
+          })
+      });
+
+      DataFrame df = context.createDataFrame(Collections.singletonList(row), schema);
+      Dataset<NestedSmallBean> ds = df.as(Encoders.bean(NestedSmallBean.class));
+
+      ds.collect();
+    }
+  }
 }
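
A note on the Java test above: SmallBean and NestedSmallBean are declared as inner (non-static) classes, so the bean encoder can only instantiate them if it knows the enclosing suite instance; that is what the OuterScopes.addOuterScope(this) call at the top of testRuntimeNullabilityCheck provides. A minimal Scala sketch of the same registration, with hypothetical class names:

    import org.apache.spark.sql.catalyst.encoders.OuterScopes

    class SomeSuite {                      // hypothetical enclosing class
      case class Inner(a: String, b: Int)  // inner class: instances carry an outer pointer

      def setup(): Unit = {
        // Register the outer instance so that encoders can construct Inner objects
        // when deserializing rows back into the inner class.
        OuterScopes.addOuterScope(this)
      }
    }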

http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 854dec0..0b7573c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -578,6 +578,21 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     assert(df.showString(10) === expectedAnswer)
   }
 
+  test("showString: binary") {
+    val df = Seq(
+      ("12".getBytes, "ABC.".getBytes),
+      ("34".getBytes, "12346".getBytes)
+    ).toDF()
+    val expectedAnswer = """+-------+----------------+
+                           ||     _1|              _2|
+                           |+-------+----------------+
+                           ||[31 32]|   [41 42 43 2E]|
+                           ||[33 34]|[31 32 33 34 36]|
+                           |+-------+----------------+
+                           |""".stripMargin
+    assert(df.showString(10) === expectedAnswer)
+  }
+
   test("showString: minimum column width") {
     val df = Seq(
       (1, 1),
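
The expected table in the new "showString: binary" test is easy to check by hand: binary cells are printed as space-separated uppercase hex. A small plain-Scala sketch (no Spark required) that reproduces the cell text, using a helper name of my own choosing:

    // Hypothetical helper mirroring the "[31 32]"-style cells in the expected output.
    def hexCell(bytes: Array[Byte]): String =
      bytes.map("%02X".format(_)).mkString("[", " ", "]")

    hexCell("12".getBytes)    // "[31 32]"
    hexCell("ABC.".getBytes)  // "[41 42 43 2E]"
    hexCell("12346".getBytes) // "[31 32 33 34 36]"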

http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index c6b3991..c19b5a4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -24,6 +24,7 @@ import scala.language.postfixOps
 
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
 
 
 class DatasetSuite extends QueryTest with SharedSQLContext {
@@ -438,6 +439,20 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     assert(ds.toString == "[_1: int, _2: int]")
   }
 
+  test("showString: Kryo encoder") {
+    implicit val kryoEncoder = Encoders.kryo[KryoData]
+    val ds = Seq(KryoData(1), KryoData(2)).toDS()
+
+    val expectedAnswer = """+-----------+
+                           ||      value|
+                           |+-----------+
+                           ||KryoData(1)|
+                           ||KryoData(2)|
+                           |+-----------+
+                           |""".stripMargin
+    assert(ds.showString(10) === expectedAnswer)
+  }
+
   test("Kryo encoder") {
     implicit val kryoEncoder = Encoders.kryo[KryoData]
     val ds = Seq(KryoData(1), KryoData(2)).toDS()
@@ -493,12 +508,63 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     val ds = Seq(2 -> 2.toByte, 3 -> 3.toByte).toDF("a", "b").as[ClassData]
     assert(ds.collect().toSeq == Seq(ClassData("2", 2), ClassData("3", 3)))
   }
-}
 
+  test("verify mismatching field names fail with a good error") {
+    val ds = Seq(ClassData("a", 1)).toDS()
+    val e = intercept[AnalysisException] {
+      ds.as[ClassData2].collect()
+    }
+    assert(e.getMessage.contains("cannot resolve 'c' given input columns: [a, b]"), e.getMessage)
+  }
+
+  test("runtime nullability check") {
+    val schema = StructType(Seq(
+      StructField("f", StructType(Seq(
+        StructField("a", StringType, nullable = true),
+        StructField("b", IntegerType, nullable = false)
+      )), nullable = true)
+    ))
+
+    def buildDataset(rows: Row*): Dataset[NestedStruct] = {
+      val rowRDD = sqlContext.sparkContext.parallelize(rows)
+      sqlContext.createDataFrame(rowRDD, schema).as[NestedStruct]
+    }
+
+    checkAnswer(
+      buildDataset(Row(Row("hello", 1))),
+      NestedStruct(ClassData("hello", 1))
+    )
+
+    // Shouldn't throw runtime exception when parent object (`ClassData`) is null
+    assert(buildDataset(Row(null)).collect() === Array(NestedStruct(null)))
+
+    val message = intercept[RuntimeException] {
+      buildDataset(Row(Row("hello", null))).collect()
+    }.getMessage
+
+    assert(message.contains(
+      "Null value appeared in non-nullable field 
org.apache.spark.sql.ClassData.b of type Int."
+    ))
+  }
+
+  test("SPARK-12478: top level null field") {
+    val ds0 = Seq(NestedStruct(null)).toDS()
+    checkAnswer(ds0, NestedStruct(null))
+    checkAnswer(ds0.toDF(), Row(null))
+
+    val ds1 = Seq(DeepNestedStruct(NestedStruct(null))).toDS()
+    checkAnswer(ds1, DeepNestedStruct(NestedStruct(null)))
+    checkAnswer(ds1.toDF(), Row(Row(null)))
+  }
+}
 
 case class ClassData(a: String, b: Int)
+case class ClassData2(c: String, d: Int)
 case class ClassNullableData(a: String, b: Integer)
 
+case class NestedStruct(f: ClassData)
+case class DeepNestedStruct(f: NestedStruct)
+
 /**
  * A class used to test serialization using encoders. This class throws exceptions when using
  * Java serialization -- so the only way it can be "serialized" is through our encoders.
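
The Int vs. Integer distinction is what drives the "runtime nullability check" test above: the encoder marks a Scala Int field as non-nullable and a boxed java.lang.Integer field as nullable. A quick sketch, assuming it runs inside the suite where the implicit encoders are in scope:

    // ClassData.b is a primitive Int, so the derived schema marks it non-nullable;
    // ClassNullableData.b is a boxed java.lang.Integer, so it stays nullable.
    Seq(ClassData("a", 1)).toDS().schema
    // a: StringType (nullable = true), b: IntegerType (nullable = false)

    Seq(ClassNullableData("a", 1)).toDS().schema
    // a: StringType (nullable = true), b: IntegerType (nullable = true)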

http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index bc22fb8..9246f55 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -21,10 +21,15 @@ import java.util.{Locale, TimeZone}
 
 import scala.collection.JavaConverters._
 
-import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.catalyst.trees.TreeNode
+import org.apache.spark.sql.catalyst.plans._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.columnar.InMemoryRelation
-import org.apache.spark.sql.execution.Queryable
+import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.execution.{LogicalRDD, Queryable}
 
 abstract class QueryTest extends PlanTest {
 
@@ -123,6 +128,8 @@ abstract class QueryTest extends PlanTest {
              |""".stripMargin)
     }
 
+    checkJsonFormat(analyzedDF)
+
     QueryTest.checkAnswer(analyzedDF, expectedAnswer) match {
       case Some(errorMessage) => fail(errorMessage)
       case None =>
@@ -177,6 +184,97 @@ abstract class QueryTest extends PlanTest {
       s"Expected query to contain $numCachedTables, but it actually had 
${cachedData.size}\n" +
         planWithCaching)
   }
+
+  private def checkJsonFormat(df: DataFrame): Unit = {
+    val logicalPlan = df.queryExecution.analyzed
+    // bypass some cases that we can't handle currently.
+    logicalPlan.transform {
+      case _: MapPartitions[_, _] => return
+      case _: MapGroups[_, _, _] => return
+      case _: AppendColumns[_, _] => return
+      case _: CoGroup[_, _, _, _] => return
+      case _: LogicalRelation => return
+    }.transformAllExpressions {
+      case a: ImperativeAggregate => return
+    }
+
+    val jsonString = try {
+      logicalPlan.toJSON
+    } catch {
+      case e =>
+        fail(
+          s"""
+             |Failed to parse logical plan to JSON:
+             |${logicalPlan.treeString}
+           """.stripMargin, e)
+    }
+
+    // bypass hive tests before we fix all corner cases in hive module.
+    if (this.getClass.getName.startsWith("org.apache.spark.sql.hive")) return
+
+    // Scala functions are not serializable to JSON, so replace them with null here so that we
+    // can compare the plans later.
+    val normalized1 = logicalPlan.transformAllExpressions {
+      case udf: ScalaUDF => udf.copy(function = null)
+      case gen: UserDefinedGenerator => gen.copy(function = null)
+    }
+
+    // RDDs and local data are not serializable to JSON, so we collect the plan nodes that
+    // contain them here, and below use the originals to fill in the null placeholders left in
+    // the plan parsed back from JSON.
+    var logicalRDDs = logicalPlan.collect { case l: LogicalRDD => l }
+    var localRelations = logicalPlan.collect { case l: LocalRelation => l }
+    var inMemoryRelations = logicalPlan.collect { case i: InMemoryRelation => i }
+
+    val jsonBackPlan = try {
+      TreeNode.fromJSON[LogicalPlan](jsonString, sqlContext.sparkContext)
+    } catch {
+      case e =>
+        fail(
+          s"""
+             |Failed to rebuild the logical plan from JSON:
+             |${logicalPlan.treeString}
+             |
+             |${logicalPlan.prettyJson}
+           """.stripMargin, e)
+    }
+
+    val normalized2 = jsonBackPlan transformDown {
+      case l: LogicalRDD =>
+        val origin = logicalRDDs.head
+        logicalRDDs = logicalRDDs.drop(1)
+        LogicalRDD(l.output, origin.rdd)(sqlContext)
+      case l: LocalRelation =>
+        val origin = localRelations.head
+        localRelations = localRelations.drop(1)
+        l.copy(data = origin.data)
+      case l: InMemoryRelation =>
+        val origin = inMemoryRelations.head
+        inMemoryRelations = inMemoryRelations.drop(1)
+        InMemoryRelation(
+          l.output,
+          l.useCompression,
+          l.batchSize,
+          l.storageLevel,
+          origin.child,
+          l.tableName)(
+          origin.cachedColumnBuffers,
+          l._statistics,
+          origin._batchStats)
+    }
+
+    assert(logicalRDDs.isEmpty)
+    assert(localRelations.isEmpty)
+    assert(inMemoryRelations.isEmpty)
+
+    if (normalized1 != normalized2) {
+      fail(
+        s"""
+           |== FAIL: the logical plan parsed from json does not match the original one ===
+           |${sideBySide(logicalPlan.treeString, normalized2.treeString).mkString("\n")}
+          """.stripMargin)
+    }
+  }
 }
 
 object QueryTest {
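
For readers skimming the diff: the heart of the new checkJsonFormat helper is a serialize/rebuild round trip of the analyzed logical plan. A stripped-down sketch of that flow, using the same names and APIs as the helper above and omitting its normalization steps:

    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    import org.apache.spark.sql.catalyst.trees.TreeNode

    val analyzed: LogicalPlan = df.queryExecution.analyzed
    val json = analyzed.toJSON                          // plan -> JSON string
    val rebuilt = TreeNode.fromJSON[LogicalPlan](json, sqlContext.sparkContext)  // JSON -> plan
    // The real helper then nulls out Scala functions and splices the original RDDs,
    // local data and cached column buffers back in before requiring the two plans to match.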

http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index f602f2f..2a11173 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -65,6 +65,11 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
   override def userClass: Class[MyDenseVector] = classOf[MyDenseVector]
 
   private[spark] override def asNullable: MyDenseVectorUDT = this
+
+  override def equals(other: Any): Boolean = other match {
+    case _: MyDenseVectorUDT => true
+    case _ => false
+  }
 }
 
class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest {
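
The new equals override makes any two MyDenseVectorUDT instances compare equal (the type carries no state of its own), which presumably keeps schema and plan comparisons stable once plans are rebuilt from JSON by the QueryTest change above. A one-line illustration:

    // Two freshly constructed instances now compare equal, so a schema that went
    // through a JSON round trip can still match the original one.
    new MyDenseVectorUDT() == new MyDenseVectorUDT()   // true with the override above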

http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 08b291e..f099e14 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -728,6 +728,8 @@ private[hive] case class MetastoreRelation
     Objects.hashCode(databaseName, tableName, alias, output)
   }
 
+  override protected def otherCopyArgs: Seq[AnyRef] = table :: sqlContext :: Nil
+
   @transient val hiveQlTable: Table = {
     // We start by constructing an API table as Hive performs several important transformations
     // internally when converting an API table to a QL table.
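
MetastoreRelation takes table and sqlContext outside of its first (case-class) parameter list, and TreeNode.makeCopy only re-supplies that first list when it rebuilds a node during a transformation; overriding otherCopyArgs tells it which extra arguments to append. A generic, self-contained sketch of the pattern (illustrative names only, not the real MetastoreRelation):

    // Illustrative only: a node whose "context" lives in a second parameter list,
    // so reflection-based copies need it supplied explicitly via otherCopyArgs.
    case class FakeRelation(databaseName: String, tableName: String)(val context: AnyRef) {
      protected def otherCopyArgs: Seq[AnyRef] = context :: Nil
    }

    // The curried argument is not part of the case-class product, so it would be
    // lost on a reflective copy unless otherCopyArgs declares it.
    val rel = FakeRelation("db", "t")("session state")
    assert(rel.productArity == 2)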

http://git-wip-us.apache.org/repos/asf/spark/blob/a6190508/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
index b30117f..d9b9ba4 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
@@ -58,7 +58,7 @@ case class ScriptTransformation(
     ioschema: HiveScriptIOSchema)(@transient private val sc: HiveContext)
   extends UnaryNode {
 
-  override def otherCopyArgs: Seq[HiveContext] = sc :: Nil
+  override protected def otherCopyArgs: Seq[HiveContext] = sc :: Nil
 
   private val serializedHiveConf = new SerializableConfiguration(sc.hiveconf)
 

