Repository: spark
Updated Branches:
  refs/heads/branch-1.6 6e3e3c648 -> 74a230676


[SPARK-11856][SQL] add type cast if the real type is different but compatible 
with encoder schema

When we build the `fromRowExpression` for an encoder, we set up a lot of 
"unresolved" stuff and lose the required data type, which may lead to a runtime 
error if the real type doesn't match the encoder's schema.
For example, we build an encoder for `case class Data(a: Int, b: String)` and 
the real type is `[a: int, b: long]`, then we will hit a runtime error saying 
that we can't construct class `Data` with int and long, because we lost the 
information that `b` should be a string.

Author: Wenchen Fan <[email protected]>

Closes #9840 from cloud-fan/err-msg.

(cherry picked from commit 9df24624afedd993a39ab46c8211ae153aedef1a)
Signed-off-by: Michael Armbrust <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/74a23067
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/74a23067
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/74a23067

Branch: refs/heads/branch-1.6
Commit: 74a2306763161fc04c9d3e7de186a6b31617faf4
Parents: 6e3e3c6
Author: Wenchen Fan <[email protected]>
Authored: Tue Dec 1 10:24:53 2015 -0800
Committer: Michael Armbrust <[email protected]>
Committed: Tue Dec 1 10:25:11 2015 -0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/ScalaReflection.scala    |  93 ++++++++--
 .../spark/sql/catalyst/analysis/Analyzer.scala  |  40 +++++
 .../catalyst/analysis/HiveTypeCoercion.scala    |   2 +-
 .../catalyst/encoders/ExpressionEncoder.scala   |   4 +-
 .../spark/sql/catalyst/expressions/Cast.scala   |   9 +
 .../expressions/complexTypeCreator.scala        |   2 +-
 .../apache/spark/sql/types/DecimalType.scala    |  12 ++
 .../encoders/EncoderResolutionSuite.scala       | 180 +++++++++++++++++++
 .../spark/sql/DatasetAggregatorSuite.scala      |   4 +-
 .../org/apache/spark/sql/DatasetSuite.scala     |  21 ++-
 10 files changed, 335 insertions(+), 32 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/74a23067/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index d133ad3..9b6b5b8 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -18,9 +18,8 @@
 package org.apache.spark.sql.catalyst
 
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedExtractValue, 
UnresolvedAttribute}
-import org.apache.spark.sql.catalyst.util.{GenericArrayData, 
ArrayBasedMapData, ArrayData, DateTimeUtils}
+import org.apache.spark.sql.catalyst.util.{GenericArrayData, 
ArrayBasedMapData, DateTimeUtils}
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils
@@ -117,31 +116,75 @@ object ScalaReflection extends ScalaReflection {
    * from ordinal 0 (since there are no names to map to).  The actual location 
can be moved by
    * calling resolve/bind with a new schema.
    */
-  def constructorFor[T : TypeTag]: Expression = constructorFor(localTypeOf[T], 
None)
+  def constructorFor[T : TypeTag]: Expression = {
+    val tpe = localTypeOf[T]
+    val clsName = getClassNameFromType(tpe)
+    val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil
+    constructorFor(tpe, None, walkedTypePath)
+  }
 
   private def constructorFor(
       tpe: `Type`,
-      path: Option[Expression]): Expression = ScalaReflectionLock.synchronized 
{
+      path: Option[Expression],
+      walkedTypePath: Seq[String]): Expression = 
ScalaReflectionLock.synchronized {
 
     /** Returns the current path with a sub-field extracted. */
-    def addToPath(part: String): Expression = path
-      .map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
-      .getOrElse(UnresolvedAttribute(part))
+    def addToPath(part: String, dataType: DataType, walkedTypePath: 
Seq[String]): Expression = {
+      val newPath = path
+        .map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
+        .getOrElse(UnresolvedAttribute(part))
+      upCastToExpectedType(newPath, dataType, walkedTypePath)
+    }
 
     /** Returns the current path with a field at ordinal extracted. */
-    def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path
-      .map(p => GetStructField(p, ordinal))
-      .getOrElse(BoundReference(ordinal, dataType, false))
+    def addToPathOrdinal(
+        ordinal: Int,
+        dataType: DataType,
+        walkedTypePath: Seq[String]): Expression = {
+      val newPath = path
+        .map(p => GetStructField(p, ordinal))
+        .getOrElse(BoundReference(ordinal, dataType, false))
+      upCastToExpectedType(newPath, dataType, walkedTypePath)
+    }
 
     /** Returns the current path or `BoundReference`. */
-    def getPath: Expression = path.getOrElse(BoundReference(0, 
schemaFor(tpe).dataType, true))
+    def getPath: Expression = {
+      val dataType = schemaFor(tpe).dataType
+      if (path.isDefined) {
+        path.get
+      } else {
+        upCastToExpectedType(BoundReference(0, dataType, true), dataType, 
walkedTypePath)
+      }
+    }
+
+    /**
+     * When we build the `fromRowExpression` for an encoder, we set up a lot 
of "unresolved" stuff
+     * and lost the required data type, which may lead to runtime error if the 
real type doesn't
+     * match the encoder's schema.
+     * For example, we build an encoder for `case class Data(a: Int, b: 
String)` and the real type
+     * is [a: int, b: long], then we will hit runtime error and say that we 
can't construct class
+     * `Data` with int and long, because we lost the information that `b` 
should be a string.
+     *
+     * This method helps us "remember" the required data type by adding a 
`UpCast`.  Note that we
+     * don't need to cast struct type because there must be 
`UnresolvedExtractValue` or
+     * `GetStructField` wrapping it, thus we only need to handle leaf type.
+     */
+    def upCastToExpectedType(
+        expr: Expression,
+        expected: DataType,
+        walkedTypePath: Seq[String]): Expression = expected match {
+      case _: StructType => expr
+      case _ => UpCast(expr, expected, walkedTypePath)
+    }
 
     tpe match {
       case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath
 
       case t if t <:< localTypeOf[Option[_]] =>
         val TypeRef(_, _, Seq(optType)) = t
-        WrapOption(constructorFor(optType, path))
+        val className = getClassNameFromType(optType)
+        val newTypePath = s"""- option value class: "$className"""" +: 
walkedTypePath
+        WrapOption(constructorFor(optType, path, newTypePath))
 
       case t if t <:< localTypeOf[java.lang.Integer] =>
         val boxedType = classOf[java.lang.Integer]
@@ -219,9 +262,11 @@ object ScalaReflection extends ScalaReflection {
         primitiveMethod.map { method =>
           Invoke(getPath, method, arrayClassFor(elementType))
         }.getOrElse {
+          val className = getClassNameFromType(elementType)
+          val newTypePath = s"""- array element class: "$className"""" +: 
walkedTypePath
           Invoke(
             MapObjects(
-              p => constructorFor(elementType, Some(p)),
+              p => constructorFor(elementType, Some(p), newTypePath),
               getPath,
               schemaFor(elementType).dataType),
             "array",
@@ -230,10 +275,12 @@ object ScalaReflection extends ScalaReflection {
 
       case t if t <:< localTypeOf[Seq[_]] =>
         val TypeRef(_, _, Seq(elementType)) = t
+        val className = getClassNameFromType(elementType)
+        val newTypePath = s"""- array element class: "$className"""" +: 
walkedTypePath
         val arrayData =
           Invoke(
             MapObjects(
-              p => constructorFor(elementType, Some(p)),
+              p => constructorFor(elementType, Some(p), newTypePath),
               getPath,
               schemaFor(elementType).dataType),
             "array",
@@ -246,12 +293,13 @@ object ScalaReflection extends ScalaReflection {
           arrayData :: Nil)
 
       case t if t <:< localTypeOf[Map[_, _]] =>
+        // TODO: add walked type path for map
         val TypeRef(_, _, Seq(keyType, valueType)) = t
 
         val keyData =
           Invoke(
             MapObjects(
-              p => constructorFor(keyType, Some(p)),
+              p => constructorFor(keyType, Some(p), walkedTypePath),
               Invoke(getPath, "keyArray", 
ArrayType(schemaFor(keyType).dataType)),
               schemaFor(keyType).dataType),
             "array",
@@ -260,7 +308,7 @@ object ScalaReflection extends ScalaReflection {
         val valueData =
           Invoke(
             MapObjects(
-              p => constructorFor(valueType, Some(p)),
+              p => constructorFor(valueType, Some(p), walkedTypePath),
               Invoke(getPath, "valueArray", 
ArrayType(schemaFor(valueType).dataType)),
               schemaFor(valueType).dataType),
             "array",
@@ -297,12 +345,19 @@ object ScalaReflection extends ScalaReflection {
           val fieldName = p.name.toString
           val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, 
actualTypeArgs)
           val dataType = schemaFor(fieldType).dataType
-
+          val clsName = getClassNameFromType(fieldType)
+          val newTypePath = s"""- field (class: "$clsName", name: 
"$fieldName")""" +: walkedTypePath
           // For tuples, we based grab the inner fields by ordinal instead of 
name.
           if (cls.getName startsWith "scala.Tuple") {
-            constructorFor(fieldType, Some(addToPathOrdinal(i, dataType)))
+            constructorFor(
+              fieldType,
+              Some(addToPathOrdinal(i, dataType, newTypePath)),
+              newTypePath)
           } else {
-            constructorFor(fieldType, Some(addToPath(fieldName)))
+            constructorFor(
+              fieldType,
+              Some(addToPath(fieldName, dataType, newTypePath)),
+              newTypePath)
           }
         }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/74a23067/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index b8f212f..765327c 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -72,6 +72,7 @@ class Analyzer(
       ResolveReferences ::
       ResolveGroupingAnalytics ::
       ResolvePivot ::
+      ResolveUpCast ::
       ResolveSortReferences ::
       ResolveGenerate ::
       ResolveFunctions ::
@@ -1182,3 +1183,42 @@ object ComputeCurrentTime extends Rule[LogicalPlan] {
     }
   }
 }
+
+/**
+ * Replace the `UpCast` expression by `Cast`, and throw exceptions if the cast 
may truncate.
+ */
+object ResolveUpCast extends Rule[LogicalPlan] {
+  private def fail(from: Expression, to: DataType, walkedTypePath: 
Seq[String]) = {
+    throw new AnalysisException(s"Cannot up cast `${from.prettyString}` from " 
+
+      s"${from.dataType.simpleString} to ${to.simpleString} as it may 
truncate\n" +
+      "The type path of the target object is:\n" + walkedTypePath.mkString("", 
"\n", "\n") +
+      "You can either add an explicit cast to the input data or choose a 
higher precision " +
+      "type of the field in the target object")
+  }
+
+  private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean 
= {
+    val fromPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(from)
+    val toPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(to)
+    toPrecedence > 0 && fromPrecedence > toPrecedence
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = {
+    plan transformAllExpressions {
+      case u @ UpCast(child, _, _) if !child.resolved => u
+
+      case UpCast(child, dataType, walkedTypePath) => (child.dataType, 
dataType) match {
+        case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) =>
+          fail(child, to, walkedTypePath)
+        case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) =>
+          fail(child, to, walkedTypePath)
+        case (from, to) if illegalNumericPrecedence(from, to) =>
+          fail(child, to, walkedTypePath)
+        case (TimestampType, DateType) =>
+          fail(child, DateType, walkedTypePath)
+        case (StringType, to: NumericType) =>
+          fail(child, to, walkedTypePath)
+        case _ => Cast(child, dataType)
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/74a23067/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index f90fc3c..29502a5 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -53,7 +53,7 @@ object HiveTypeCoercion {
 
   // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types.
   // The conversion for integral and floating point types have a linear 
widening hierarchy:
-  private val numericPrecedence =
+  private[sql] val numericPrecedence =
     IndexedSeq(
       ByteType,
       ShortType,

http://git-wip-us.apache.org/repos/asf/spark/blob/74a23067/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 0c10a56..06ffe86 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -28,6 +28,7 @@ import 
org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtract
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
 import org.apache.spark.sql.catalyst.expressions._
 import 
org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, 
GenerateUnsafeProjection}
+import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.types.{StructField, ObjectType, StructType}
@@ -235,12 +236,13 @@ case class ExpressionEncoder[T](
 
     val plan = Project(Alias(unbound, "")() :: Nil, LocalRelation(schema))
     val analyzedPlan = SimpleAnalyzer.execute(plan)
+    val optimizedPlan = SimplifyCasts(analyzedPlan)
 
     // In order to construct instances of inner classes (for example those 
declared in a REPL cell),
     // we need an instance of the outer scope.  This rule substitues those 
outer objects into
     // expressions that are missing them by looking up the name in the 
SQLContexts `outerScopes`
     // registry.
-    copy(fromRowExpression = analyzedPlan.expressions.head.children.head 
transform {
+    copy(fromRowExpression = optimizedPlan.expressions.head.children.head 
transform {
       case n: NewInstance if n.outerPointer.isEmpty && n.cls.isMemberClass =>
         val outer = outerScopes.get(n.cls.getDeclaringClass.getName)
         if (outer == null) {

http://git-wip-us.apache.org/repos/asf/spark/blob/74a23067/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 533d17e..79e0438 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -915,3 +915,12 @@ case class Cast(child: Expression, dataType: DataType)
       """
   }
 }
+
+/**
+ * Cast the child expression to the target data type, but will throw error if 
the cast might
+ * truncate, e.g. long -> int, timestamp -> date.
+ */
+case class UpCast(child: Expression, dataType: DataType, walkedTypePath: 
Seq[String])
+  extends UnaryExpression with Unevaluable {
+  override lazy val resolved = false
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/74a23067/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index 1854dfa..72cc89c 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -126,7 +126,7 @@ case class CreateStruct(children: Seq[Expression]) extends 
Expression {
 case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
 
   /**
-   * Returns Aliased [[Expressions]] that could be used to construct a 
flattened version of this
+   * Returns Aliased [[Expression]]s that could be used to construct a 
flattened version of this
    * StructType.
    */
   def flatten: Seq[NamedExpression] = valExprs.zip(names).map {

http://git-wip-us.apache.org/repos/asf/spark/blob/74a23067/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
index 0cd352d..ce45245 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala
@@ -91,6 +91,18 @@ case class DecimalType(precision: Int, scale: Int) extends 
FractionalType {
   }
 
   /**
+   * Returns whether this DecimalType is tighter than `other`. If yes, it 
means `this`
+   * can be cast into `other` safely without losing any precision or range.
+   */
+  private[sql] def isTighterThan(other: DataType): Boolean = other match {
+    case dt: DecimalType =>
+      (precision - scale) <= (dt.precision - dt.scale) && scale <= dt.scale
+    case dt: IntegralType =>
+      isTighterThan(DecimalType.forType(dt))
+    case _ => false
+  }
+
+  /**
    * The default size of a value of the DecimalType is 4096 bytes.
    */
   override def defaultSize: Int = 4096

http://git-wip-us.apache.org/repos/asf/spark/blob/74a23067/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
new file mode 100644
index 0000000..0289988
--- /dev/null
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
@@ -0,0 +1,180 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.encoders
+
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.types._
+
+case class StringLongClass(a: String, b: Long)
+
+case class StringIntClass(a: String, b: Int)
+
+case class ComplexClass(a: Long, b: StringLongClass)
+
+class EncoderResolutionSuite extends PlanTest {
+  test("real type doesn't match encoder schema but they are compatible: 
product") {
+    val encoder = ExpressionEncoder[StringLongClass]
+    val cls = classOf[StringLongClass]
+
+    {
+      val attrs = Seq('a.string, 'b.int)
+      val fromRowExpr: Expression = encoder.resolve(attrs, 
null).fromRowExpression
+      val expected: Expression = NewInstance(
+        cls,
+        toExternalString('a.string) :: 'b.int.cast(LongType) :: Nil,
+        false,
+        ObjectType(cls))
+      compareExpressions(fromRowExpr, expected)
+    }
+
+    {
+      val attrs = Seq('a.int, 'b.long)
+      val fromRowExpr = encoder.resolve(attrs, null).fromRowExpression
+      val expected = NewInstance(
+        cls,
+        toExternalString('a.int.cast(StringType)) :: 'b.long :: Nil,
+        false,
+        ObjectType(cls))
+      compareExpressions(fromRowExpr, expected)
+    }
+  }
+
+  test("real type doesn't match encoder schema but they are compatible: nested 
product") {
+    val encoder = ExpressionEncoder[ComplexClass]
+    val innerCls = classOf[StringLongClass]
+    val cls = classOf[ComplexClass]
+
+    val structType = new StructType().add("a", IntegerType).add("b", LongType)
+    val attrs = Seq('a.int, 'b.struct(structType))
+    val fromRowExpr: Expression = encoder.resolve(attrs, 
null).fromRowExpression
+    val expected: Expression = NewInstance(
+      cls,
+      Seq(
+        'a.int.cast(LongType),
+        If(
+          'b.struct(structType).isNull,
+          Literal.create(null, ObjectType(innerCls)),
+          NewInstance(
+            innerCls,
+            Seq(
+              toExternalString(
+                GetStructField('b.struct(structType), 0, 
Some("a")).cast(StringType)),
+              GetStructField('b.struct(structType), 1, Some("b"))),
+            false,
+            ObjectType(innerCls))
+        )),
+      false,
+      ObjectType(cls))
+    compareExpressions(fromRowExpr, expected)
+  }
+
+  test("real type doesn't match encoder schema but they are compatible: tupled 
encoder") {
+    val encoder = ExpressionEncoder.tuple(
+      ExpressionEncoder[StringLongClass],
+      ExpressionEncoder[Long])
+    val cls = classOf[StringLongClass]
+
+    val structType = new StructType().add("a", StringType).add("b", ByteType)
+    val attrs = Seq('a.struct(structType), 'b.int)
+    val fromRowExpr: Expression = encoder.resolve(attrs, 
null).fromRowExpression
+    val expected: Expression = NewInstance(
+      classOf[Tuple2[_, _]],
+      Seq(
+        NewInstance(
+          cls,
+          Seq(
+            toExternalString(GetStructField('a.struct(structType), 0, 
Some("a"))),
+            GetStructField('a.struct(structType), 1, 
Some("b")).cast(LongType)),
+          false,
+          ObjectType(cls)),
+        'b.int.cast(LongType)),
+      false,
+      ObjectType(classOf[Tuple2[_, _]]))
+    compareExpressions(fromRowExpr, expected)
+  }
+
+  private def toExternalString(e: Expression): Expression = {
+    Invoke(e, "toString", ObjectType(classOf[String]), Nil)
+  }
+
+  test("throw exception if real type is not compatible with encoder schema") {
+    val msg1 = intercept[AnalysisException] {
+      ExpressionEncoder[StringIntClass].resolve(Seq('a.string, 'b.long), null)
+    }.message
+    assert(msg1 ==
+      s"""
+         |Cannot up cast `b` from bigint to int as it may truncate
+         |The type path of the target object is:
+         |- field (class: "scala.Int", name: "b")
+         |- root class: "org.apache.spark.sql.catalyst.encoders.StringIntClass"
+         |You can either add an explicit cast to the input data or choose a 
higher precision type
+       """.stripMargin.trim + " of the field in the target object")
+
+    val msg2 = intercept[AnalysisException] {
+      val structType = new StructType().add("a", StringType).add("b", 
DecimalType.SYSTEM_DEFAULT)
+      ExpressionEncoder[ComplexClass].resolve(Seq('a.long, 
'b.struct(structType)), null)
+    }.message
+    assert(msg2 ==
+      s"""
+         |Cannot up cast `b.b` from decimal(38,18) to bigint as it may truncate
+         |The type path of the target object is:
+         |- field (class: "scala.Long", name: "b")
+         |- field (class: 
"org.apache.spark.sql.catalyst.encoders.StringLongClass", name: "b")
+         |- root class: "org.apache.spark.sql.catalyst.encoders.ComplexClass"
+         |You can either add an explicit cast to the input data or choose a 
higher precision type
+       """.stripMargin.trim + " of the field in the target object")
+  }
+
+  // test for leaf types
+  castSuccess[Int, Long]
+  castSuccess[java.sql.Date, java.sql.Timestamp]
+  castSuccess[Long, String]
+  castSuccess[Int, java.math.BigDecimal]
+  castSuccess[Long, java.math.BigDecimal]
+
+  castFail[Long, Int]
+  castFail[java.sql.Timestamp, java.sql.Date]
+  castFail[java.math.BigDecimal, Double]
+  castFail[Double, java.math.BigDecimal]
+  castFail[java.math.BigDecimal, Int]
+  castFail[String, Long]
+
+
+  private def castSuccess[T: TypeTag, U: TypeTag]: Unit = {
+    val from = ExpressionEncoder[T]
+    val to = ExpressionEncoder[U]
+    val catalystType = from.schema.head.dataType.simpleString
+    test(s"cast from $catalystType to ${implicitly[TypeTag[U]].tpe} should 
success") {
+      to.resolve(from.schema.toAttributes, null)
+    }
+  }
+
+  private def castFail[T: TypeTag, U: TypeTag]: Unit = {
+    val from = ExpressionEncoder[T]
+    val to = ExpressionEncoder[U]
+    val catalystType = from.schema.head.dataType.simpleString
+    test(s"cast from $catalystType to ${implicitly[TypeTag[U]].tpe} should 
fail") {
+      intercept[AnalysisException](to.resolve(from.schema.toAttributes, null))
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/74a23067/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 19dce5d..c6d2bf0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -131,9 +131,9 @@ class DatasetAggregatorSuite extends QueryTest with 
SharedSQLContext {
     checkAnswer(
       ds.groupBy(_._1).agg(
         sum(_._2),
-        expr("sum(_2)").as[Int],
+        expr("sum(_2)").as[Long],
         count("*")),
-      ("a", 30, 30, 2L), ("b", 3, 3, 2L), ("c", 1, 1, 1L))
+      ("a", 30, 30L, 2L), ("b", 3, 3L, 2L), ("c", 1, 1L, 1L))
   }
 
   test("typed aggregation: complex case") {

http://git-wip-us.apache.org/repos/asf/spark/blob/74a23067/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index a2c8d20..542e4d6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -335,24 +335,24 @@ class DatasetSuite extends QueryTest with 
SharedSQLContext {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
 
     checkAnswer(
-      ds.groupBy(_._1).agg(sum("_2").as[Int]),
-      ("a", 30), ("b", 3), ("c", 1))
+      ds.groupBy(_._1).agg(sum("_2").as[Long]),
+      ("a", 30L), ("b", 3L), ("c", 1L))
   }
 
   test("typed aggregation: expr, expr") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
 
     checkAnswer(
-      ds.groupBy(_._1).agg(sum("_2").as[Int], sum($"_2" + 1).as[Long]),
-      ("a", 30, 32L), ("b", 3, 5L), ("c", 1, 2L))
+      ds.groupBy(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long]),
+      ("a", 30L, 32L), ("b", 3L, 5L), ("c", 1L, 2L))
   }
 
   test("typed aggregation: expr, expr, expr") {
     val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
 
     checkAnswer(
-      ds.groupBy(_._1).agg(sum("_2").as[Int], sum($"_2" + 1).as[Long], 
count("*").as[Long]),
-      ("a", 30, 32L, 2L), ("b", 3, 5L, 2L), ("c", 1, 2L, 1L))
+      ds.groupBy(_._1).agg(sum("_2").as[Long], sum($"_2" + 1).as[Long], 
count("*")),
+      ("a", 30L, 32L, 2L), ("b", 3L, 5L, 2L), ("c", 1L, 2L, 1L))
   }
 
   test("typed aggregation: expr, expr, expr, expr") {
@@ -360,11 +360,11 @@ class DatasetSuite extends QueryTest with 
SharedSQLContext {
 
     checkAnswer(
       ds.groupBy(_._1).agg(
-        sum("_2").as[Int],
+        sum("_2").as[Long],
         sum($"_2" + 1).as[Long],
         count("*").as[Long],
         avg("_2").as[Double]),
-      ("a", 30, 32L, 2L, 15.0), ("b", 3, 5L, 2L, 1.5), ("c", 1, 2L, 1L, 1.0))
+      ("a", 30L, 32L, 2L, 15.0), ("b", 3L, 5L, 2L, 1.5), ("c", 1L, 2L, 1L, 
1.0))
   }
 
   test("cogroup") {
@@ -476,6 +476,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext 
{
       ((nullInt, "1"), (new java.lang.Integer(22), "2")),
       ((new java.lang.Integer(22), "2"), (new java.lang.Integer(22), "2")))
   }
+
+  test("change encoder with compatible schema") {
+    val ds = Seq(2 -> 2.toByte, 3 -> 3.toByte).toDF("a", "b").as[ClassData]
+    assert(ds.collect().toSeq == Seq(ClassData("2", 2), ClassData("3", 3)))
+  }
 }
 
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to