Repository: spark
Updated Branches:
  refs/heads/master 044f7ecbf -> ffc57b011


[SPARK-20302][SQL] Short circuit cast when from and to types are structurally the same

## What changes were proposed in this pull request?
When a cast expression's source and target types are structurally the same
(identical shape and nullability, differing only in struct field names), we can
skip the actual cast and return the child's value as is, in both the interpreted
and the codegen paths.

## How was this patch tested?
Added unit tests for the newly introduced `DataType.equalsStructurally` function
and for casting between structurally equal struct types.

Author: Reynold Xin <r...@databricks.com>

Closes #17614 from rxin/SPARK-20302.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ffc57b01
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ffc57b01
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ffc57b01

Branch: refs/heads/master
Commit: ffc57b0118b58de57520967d8e8730b11baad507
Parents: 044f7ecb
Author: Reynold Xin <r...@databricks.com>
Authored: Wed Apr 12 01:30:00 2017 -0700
Committer: Reynold Xin <r...@databricks.com>
Committed: Wed Apr 12 01:30:00 2017 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/expressions/Cast.scala   | 65 +++++++++++++-------
 .../org/apache/spark/sql/types/DataType.scala   | 26 ++++++++
 .../sql/catalyst/expressions/CastSuite.scala    | 14 +++++
 .../apache/spark/sql/types/DataTypeSuite.scala  | 31 ++++++++++
 4 files changed, 113 insertions(+), 23 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ffc57b01/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 1049915..bb1273f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -462,35 +462,54 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
     })
   }
 
-  private[this] def cast(from: DataType, to: DataType): Any => Any = to match {
-    case dt if dt == from => identity[Any]
-    case StringType => castToString(from)
-    case BinaryType => castToBinary(from)
-    case DateType => castToDate(from)
-    case decimal: DecimalType => castToDecimal(from, decimal)
-    case TimestampType => castToTimestamp(from)
-    case CalendarIntervalType => castToInterval(from)
-    case BooleanType => castToBoolean(from)
-    case ByteType => castToByte(from)
-    case ShortType => castToShort(from)
-    case IntegerType => castToInt(from)
-    case FloatType => castToFloat(from)
-    case LongType => castToLong(from)
-    case DoubleType => castToDouble(from)
-    case array: ArrayType => castArray(from.asInstanceOf[ArrayType].elementType, array.elementType)
-    case map: MapType => castMap(from.asInstanceOf[MapType], map)
-    case struct: StructType => castStruct(from.asInstanceOf[StructType], struct)
-    case udt: UserDefinedType[_]
-      if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass =>
-      identity[Any]
-    case _: UserDefinedType[_] =>
-      throw new SparkException(s"Cannot cast $from to $to.")
+  private[this] def cast(from: DataType, to: DataType): Any => Any = {
+    // If the cast does not change the structure, then we don't really need to cast anything.
+    // We can return what the children return. Same thing should happen in the codegen path.
+    if (DataType.equalsStructurally(from, to)) {
+      identity
+    } else {
+      to match {
+        case dt if dt == from => identity[Any]
+        case StringType => castToString(from)
+        case BinaryType => castToBinary(from)
+        case DateType => castToDate(from)
+        case decimal: DecimalType => castToDecimal(from, decimal)
+        case TimestampType => castToTimestamp(from)
+        case CalendarIntervalType => castToInterval(from)
+        case BooleanType => castToBoolean(from)
+        case ByteType => castToByte(from)
+        case ShortType => castToShort(from)
+        case IntegerType => castToInt(from)
+        case FloatType => castToFloat(from)
+        case LongType => castToLong(from)
+        case DoubleType => castToDouble(from)
+        case array: ArrayType =>
+          castArray(from.asInstanceOf[ArrayType].elementType, array.elementType)
+        case map: MapType => castMap(from.asInstanceOf[MapType], map)
+        case struct: StructType => castStruct(from.asInstanceOf[StructType], struct)
+        case udt: UserDefinedType[_]
+          if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass =>
+          identity[Any]
+        case _: UserDefinedType[_] =>
+          throw new SparkException(s"Cannot cast $from to $to.")
+      }
+    }
   }
 
   private[this] lazy val cast: Any => Any = cast(child.dataType, dataType)
 
   protected override def nullSafeEval(input: Any): Any = cast(input)
 
+  override def genCode(ctx: CodegenContext): ExprCode = {
+    // If the cast does not change the structure, then we don't really need to cast anything.
+    // We can return what the children return. Same thing should happen in the interpreted path.
+    if (DataType.equalsStructurally(child.dataType, dataType)) {
+      child.genCode(ctx)
+    } else {
+      super.genCode(ctx)
+    }
+  }
+
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val eval = child.genCode(ctx)
     val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx)

http://git-wip-us.apache.org/repos/asf/spark/blob/ffc57b01/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index 520aff5..30745c6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -288,4 +288,30 @@ object DataType {
       case (fromDataType, toDataType) => fromDataType == toDataType
     }
   }
+
+  /**
+   * Returns true if the two data types share the same "shape", i.e. the types (including
+   * nullability) are the same, but the field names don't need to be the same.
+   */
+  def equalsStructurally(from: DataType, to: DataType): Boolean = {
+    (from, to) match {
+      case (left: ArrayType, right: ArrayType) =>
+        equalsStructurally(left.elementType, right.elementType) &&
+          left.containsNull == right.containsNull
+
+      case (left: MapType, right: MapType) =>
+        equalsStructurally(left.keyType, right.keyType) &&
+          equalsStructurally(left.valueType, right.valueType) &&
+          left.valueContainsNull == right.valueContainsNull
+
+      case (StructType(fromFields), StructType(toFields)) =>
+        fromFields.length == toFields.length &&
+          fromFields.zip(toFields)
+            .forall { case (l, r) =>
+              equalsStructurally(l.dataType, r.dataType) && l.nullable == r.nullable
+            }
+
+      case (fromDataType, toDataType) => fromDataType == toDataType
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ffc57b01/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index 8eccadb..a7ffa88 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -813,4 +813,18 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
     assert(cast(1.0.toFloat, DateType).checkInputDataTypes().isFailure)
     assert(cast(1.0, DateType).checkInputDataTypes().isFailure)
   }
+
+  test("SPARK-20302 cast with same structure") {
+    val from = new StructType()
+      .add("a", IntegerType)
+      .add("b", new StructType().add("b1", LongType))
+
+    val to = new StructType()
+      .add("a1", IntegerType)
+      .add("b1", new StructType().add("b11", LongType))
+
+    val input = Row(10, Row(12L))
+
+    checkEvaluation(cast(Literal.create(input, from), to), input)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ffc57b01/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
index f078ef0..c4635c8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala
@@ -411,4 +411,35 @@ class DataTypeSuite extends SparkFunSuite {
   checkCatalogString(ArrayType(createStruct(40)))
   checkCatalogString(MapType(IntegerType, StringType))
   checkCatalogString(MapType(IntegerType, createStruct(40)))
+
+  def checkEqualsStructurally(from: DataType, to: DataType, expected: Boolean): Unit = {
+    val testName = s"equalsStructurally: (from: $from, to: $to)"
+    test(testName) {
+      assert(DataType.equalsStructurally(from, to) === expected)
+    }
+  }
+
+  checkEqualsStructurally(BooleanType, BooleanType, true)
+  checkEqualsStructurally(IntegerType, IntegerType, true)
+  checkEqualsStructurally(IntegerType, LongType, false)
+  checkEqualsStructurally(ArrayType(IntegerType, true), ArrayType(IntegerType, true), true)
+  checkEqualsStructurally(ArrayType(IntegerType, true), ArrayType(IntegerType, false), false)
+
+  checkEqualsStructurally(
+    new StructType().add("f1", IntegerType),
+    new StructType().add("f2", IntegerType),
+    true)
+  checkEqualsStructurally(
+    new StructType().add("f1", IntegerType),
+    new StructType().add("f2", IntegerType, false),
+    false)
+
+  checkEqualsStructurally(
+    new StructType().add("f1", IntegerType).add("f", new StructType().add("f2", StringType)),
+    new StructType().add("f2", IntegerType).add("g", new StructType().add("f1", StringType)),
+    true)
+  checkEqualsStructurally(
+    new StructType().add("f1", IntegerType).add("f", new StructType().add("f2", StringType, false)),
+    new StructType().add("f2", IntegerType).add("g", new StructType().add("f1", StringType)),
+    false)
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to