Repository: spark
Updated Branches:
  refs/heads/master 6a6010f00 -> 0a9c02759


[SPARK-15956][SQL] When unwrapping ORC avoid pattern matching at runtime

## What changes were proposed in this pull request?

Build unwrapper functions ahead of time for all types, not just primitive types, so that pattern matching on the object inspector is no longer done per row.

## How was this patch tested?

The patch should pass all unit tests. With this change, reading ORC files that 
contain non-primitive types took ~15% less time.

===

The GitHub diff is very noisy, so screenshots are attached below for improved 
readability:

![screen shot 2016-06-14 at 5 33 16 
pm](https://cloud.githubusercontent.com/assets/1514239/16064580/4d6f7a98-3257-11e6-9172-65e4baff948b.png)

![screen shot 2016-06-14 at 5 33 28 
pm](https://cloud.githubusercontent.com/assets/1514239/16064587/5ae6c244-3257-11e6-8460-69eee70de219.png)

Author: Brian Cho <[email protected]>

Closes #13676 from dafrista/improve-orc-master.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0a9c0275
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0a9c0275
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0a9c0275

Branch: refs/heads/master
Commit: 0a9c02759515c41de37db6381750bc3a316c860c
Parents: 6a6010f
Author: Brian Cho <[email protected]>
Authored: Wed Jun 22 10:38:42 2016 -0700
Committer: Herman van Hovell <[email protected]>
Committed: Wed Jun 22 10:38:42 2016 -0700

----------------------------------------------------------------------
 .../apache/spark/sql/hive/HiveInspectors.scala  | 428 +++++++++++++------
 .../org/apache/spark/sql/hive/TableReader.scala |   3 +-
 .../hive/execution/ScriptTransformation.scala   |   6 +-
 .../org/apache/spark/sql/hive/hiveUDFs.scala    |  21 +-
 .../spark/sql/hive/HiveInspectorSuite.scala     |   6 +
 5 files changed, 314 insertions(+), 150 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0a9c0275/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
index 585befe..1aadc8b 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -239,145 +239,6 @@ private[hive] trait HiveInspectors {
   }
 
   /**
-   * Converts hive types to native catalyst types.
-   * @param data the data in Hive type
-   * @param oi   the ObjectInspector associated with the Hive Type
-   * @return     convert the data into catalyst type
-   * TODO return the function of (data => Any) instead for performance 
consideration
-   *
-   * Strictly follows the following order in unwrapping (constant OI has the 
higher priority):
-   *  Constant Null object inspector =>
-   *    return null
-   *  Constant object inspector =>
-   *    extract the value from constant object inspector
-   *  Check whether the `data` is null =>
-   *    return null if true
-   *  If object inspector prefers writable =>
-   *    extract writable from `data` and then get the catalyst type from the 
writable
-   *  Extract the java object directly from the object inspector
-   *
-   *  NOTICE: the complex data type requires recursive unwrapping.
-   */
-  def unwrap(data: Any, oi: ObjectInspector): Any = oi match {
-    case coi: ConstantObjectInspector if coi.getWritableConstantValue == null 
=> null
-    case poi: WritableConstantStringObjectInspector =>
-      UTF8String.fromString(poi.getWritableConstantValue.toString)
-    case poi: WritableConstantHiveVarcharObjectInspector =>
-      
UTF8String.fromString(poi.getWritableConstantValue.getHiveVarchar.getValue)
-    case poi: WritableConstantHiveCharObjectInspector =>
-      UTF8String.fromString(poi.getWritableConstantValue.getHiveChar.getValue)
-    case poi: WritableConstantHiveDecimalObjectInspector =>
-      HiveShim.toCatalystDecimal(
-        PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector,
-        poi.getWritableConstantValue.getHiveDecimal)
-    case poi: WritableConstantTimestampObjectInspector =>
-      val t = poi.getWritableConstantValue
-      t.getSeconds * 1000000L + t.getNanos / 1000L
-    case poi: WritableConstantIntObjectInspector =>
-      poi.getWritableConstantValue.get()
-    case poi: WritableConstantDoubleObjectInspector =>
-      poi.getWritableConstantValue.get()
-    case poi: WritableConstantBooleanObjectInspector =>
-      poi.getWritableConstantValue.get()
-    case poi: WritableConstantLongObjectInspector =>
-      poi.getWritableConstantValue.get()
-    case poi: WritableConstantFloatObjectInspector =>
-      poi.getWritableConstantValue.get()
-    case poi: WritableConstantShortObjectInspector =>
-      poi.getWritableConstantValue.get()
-    case poi: WritableConstantByteObjectInspector =>
-      poi.getWritableConstantValue.get()
-    case poi: WritableConstantBinaryObjectInspector =>
-      val writable = poi.getWritableConstantValue
-      val temp = new Array[Byte](writable.getLength)
-      System.arraycopy(writable.getBytes, 0, temp, 0, temp.length)
-      temp
-    case poi: WritableConstantDateObjectInspector =>
-      DateTimeUtils.fromJavaDate(poi.getWritableConstantValue.get())
-    case mi: StandardConstantMapObjectInspector =>
-      // take the value from the map inspector object, rather than the input 
data
-      val keyValues = mi.getWritableConstantValue.asScala.toSeq
-      val keys = keyValues.map(kv => unwrap(kv._1, 
mi.getMapKeyObjectInspector)).toArray
-      val values = keyValues.map(kv => unwrap(kv._2, 
mi.getMapValueObjectInspector)).toArray
-      ArrayBasedMapData(keys, values)
-    case li: StandardConstantListObjectInspector =>
-      // take the value from the list inspector object, rather than the input 
data
-      val values = li.getWritableConstantValue.asScala
-        .map(unwrap(_, li.getListElementObjectInspector))
-        .toArray
-      new GenericArrayData(values)
-    // if the value is null, we don't care about the object inspector type
-    case _ if data == null => null
-    case poi: VoidObjectInspector => null // always be null for void object 
inspector
-    case pi: PrimitiveObjectInspector => pi match {
-      // We think HiveVarchar/HiveChar is also a String
-      case hvoi: HiveVarcharObjectInspector if hvoi.preferWritable() =>
-        
UTF8String.fromString(hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue)
-      case hvoi: HiveVarcharObjectInspector =>
-        UTF8String.fromString(hvoi.getPrimitiveJavaObject(data).getValue)
-      case hvoi: HiveCharObjectInspector if hvoi.preferWritable() =>
-        
UTF8String.fromString(hvoi.getPrimitiveWritableObject(data).getHiveChar.getValue)
-      case hvoi: HiveCharObjectInspector =>
-        UTF8String.fromString(hvoi.getPrimitiveJavaObject(data).getValue)
-      case x: StringObjectInspector if x.preferWritable() =>
-        // Text is in UTF-8 already. No need to convert again via fromString. 
Copy bytes
-        val wObj = x.getPrimitiveWritableObject(data)
-        val result = wObj.copyBytes()
-        UTF8String.fromBytes(result, 0, result.length)
-      case x: StringObjectInspector =>
-        UTF8String.fromString(x.getPrimitiveJavaObject(data))
-      case x: IntObjectInspector if x.preferWritable() => x.get(data)
-      case x: BooleanObjectInspector if x.preferWritable() => x.get(data)
-      case x: FloatObjectInspector if x.preferWritable() => x.get(data)
-      case x: DoubleObjectInspector if x.preferWritable() => x.get(data)
-      case x: LongObjectInspector if x.preferWritable() => x.get(data)
-      case x: ShortObjectInspector if x.preferWritable() => x.get(data)
-      case x: ByteObjectInspector if x.preferWritable() => x.get(data)
-      case x: HiveDecimalObjectInspector => HiveShim.toCatalystDecimal(x, data)
-      case x: BinaryObjectInspector if x.preferWritable() =>
-        // BytesWritable.copyBytes() only available since Hadoop2
-        // In order to keep backward-compatible, we have to copy the
-        // bytes with old apis
-        val bw = x.getPrimitiveWritableObject(data)
-        val result = new Array[Byte](bw.getLength())
-        System.arraycopy(bw.getBytes(), 0, result, 0, bw.getLength())
-        result
-      case x: DateObjectInspector if x.preferWritable() =>
-        DateTimeUtils.fromJavaDate(x.getPrimitiveWritableObject(data).get())
-      case x: DateObjectInspector => 
DateTimeUtils.fromJavaDate(x.getPrimitiveJavaObject(data))
-      case x: TimestampObjectInspector if x.preferWritable() =>
-        val t = x.getPrimitiveWritableObject(data)
-        t.getSeconds * 1000000L + t.getNanos / 1000L
-      case ti: TimestampObjectInspector =>
-        DateTimeUtils.fromJavaTimestamp(ti.getPrimitiveJavaObject(data))
-      case _ => pi.getPrimitiveJavaObject(data)
-    }
-    case li: ListObjectInspector =>
-      Option(li.getList(data))
-        .map { l =>
-          val values = l.asScala.map(unwrap(_, 
li.getListElementObjectInspector)).toArray
-          new GenericArrayData(values)
-        }
-        .orNull
-    case mi: MapObjectInspector =>
-      val map = mi.getMap(data)
-      if (map == null) {
-        null
-      } else {
-        val keyValues = map.asScala.toSeq
-        val keys = keyValues.map(kv => unwrap(kv._1, 
mi.getMapKeyObjectInspector)).toArray
-        val values = keyValues.map(kv => unwrap(kv._2, 
mi.getMapValueObjectInspector)).toArray
-        ArrayBasedMapData(keys, values)
-      }
-    // currently, hive doesn't provide the ConstantStructObjectInspector
-    case si: StructObjectInspector =>
-      val allRefs = si.getAllStructFieldRefs
-      InternalRow.fromSeq(allRefs.asScala.map(
-        r => unwrap(si.getStructFieldData(data, r), 
r.getFieldObjectInspector)))
-  }
-
-
-  /**
    * Wraps with Hive types based on object inspector.
    * TODO: Consolidate all hive OI/data interface code.
    */
@@ -479,8 +340,292 @@ private[hive] trait HiveInspectors {
   }
 
   /**
-   * Builds specific unwrappers ahead of time according to object inspector
+   * Builds unwrappers ahead of time according to object inspector
    * types to avoid pattern matching and branching costs per row.
+   *
+   * Strictly follows the following order in unwrapping (constant OI has the 
higher priority):
+   * Constant Null object inspector =>
+   *   return null
+   * Constant object inspector =>
+   *   extract the value from constant object inspector
+   * If object inspector prefers writable =>
+   *   extract writable from `data` and then get the catalyst type from the 
writable
+   * Extract the java object directly from the object inspector
+   *
+   * NOTICE: the complex data type requires recursive unwrapping.
+   *
+   * @param objectInspector the ObjectInspector used to create an unwrapper.
+   * @return A function that unwraps data objects.
+   *         Use the overloaded HiveStructField version for in-place updating 
of a MutableRow.
+   */
+  def unwrapperFor(objectInspector: ObjectInspector): Any => Any =
+    objectInspector match {
+      case coi: ConstantObjectInspector if coi.getWritableConstantValue == 
null =>
+        _ => null
+      case poi: WritableConstantStringObjectInspector =>
+        val constant = 
UTF8String.fromString(poi.getWritableConstantValue.toString)
+        _ => constant
+      case poi: WritableConstantHiveVarcharObjectInspector =>
+        val constant = 
UTF8String.fromString(poi.getWritableConstantValue.getHiveVarchar.getValue)
+        _ => constant
+      case poi: WritableConstantHiveCharObjectInspector =>
+        val constant = 
UTF8String.fromString(poi.getWritableConstantValue.getHiveChar.getValue)
+        _ => constant
+      case poi: WritableConstantHiveDecimalObjectInspector =>
+        val constant = HiveShim.toCatalystDecimal(
+          PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector,
+          poi.getWritableConstantValue.getHiveDecimal)
+        _ => constant
+      case poi: WritableConstantTimestampObjectInspector =>
+        val t = poi.getWritableConstantValue
+        val constant = t.getSeconds * 1000000L + t.getNanos / 1000L
+        _ => constant
+      case poi: WritableConstantIntObjectInspector =>
+        val constant = poi.getWritableConstantValue.get()
+        _ => constant
+      case poi: WritableConstantDoubleObjectInspector =>
+        val constant = poi.getWritableConstantValue.get()
+        _ => constant
+      case poi: WritableConstantBooleanObjectInspector =>
+        val constant = poi.getWritableConstantValue.get()
+        _ => constant
+      case poi: WritableConstantLongObjectInspector =>
+        val constant = poi.getWritableConstantValue.get()
+        _ => constant
+      case poi: WritableConstantFloatObjectInspector =>
+        val constant = poi.getWritableConstantValue.get()
+        _ => constant
+      case poi: WritableConstantShortObjectInspector =>
+        val constant = poi.getWritableConstantValue.get()
+        _ => constant
+      case poi: WritableConstantByteObjectInspector =>
+        val constant = poi.getWritableConstantValue.get()
+        _ => constant
+      case poi: WritableConstantBinaryObjectInspector =>
+        val writable = poi.getWritableConstantValue
+        val constant = new Array[Byte](writable.getLength)
+        System.arraycopy(writable.getBytes, 0, constant, 0, constant.length)
+        _ => constant
+      case poi: WritableConstantDateObjectInspector =>
+        val constant = 
DateTimeUtils.fromJavaDate(poi.getWritableConstantValue.get())
+        _ => constant
+      case mi: StandardConstantMapObjectInspector =>
+        val keyUnwrapper = unwrapperFor(mi.getMapKeyObjectInspector)
+        val valueUnwrapper = unwrapperFor(mi.getMapValueObjectInspector)
+        val keyValues = mi.getWritableConstantValue.asScala.toSeq
+        val keys = keyValues.map(kv => keyUnwrapper(kv._1)).toArray
+        val values = keyValues.map(kv => valueUnwrapper(kv._2)).toArray
+        val constant = ArrayBasedMapData(keys, values)
+        _ => constant
+      case li: StandardConstantListObjectInspector =>
+        val unwrapper = unwrapperFor(li.getListElementObjectInspector)
+        val values = li.getWritableConstantValue.asScala
+          .map(unwrapper)
+          .toArray
+        val constant = new GenericArrayData(values)
+        _ => constant
+      case poi: VoidObjectInspector =>
+        _ => null // always be null for void object inspector
+      case pi: PrimitiveObjectInspector => pi match {
+        // We think HiveVarchar/HiveChar is also a String
+        case hvoi: HiveVarcharObjectInspector if hvoi.preferWritable() =>
+          data: Any => {
+            if (data != null) {
+              
UTF8String.fromString(hvoi.getPrimitiveWritableObject(data).getHiveVarchar.getValue)
+            } else {
+              null
+            }
+          }
+        case hvoi: HiveVarcharObjectInspector =>
+          data: Any => {
+            if (data != null) {
+              UTF8String.fromString(hvoi.getPrimitiveJavaObject(data).getValue)
+            } else {
+              null
+            }
+          }
+        case hvoi: HiveCharObjectInspector if hvoi.preferWritable() =>
+          data: Any => {
+            if (data != null) {
+              
UTF8String.fromString(hvoi.getPrimitiveWritableObject(data).getHiveChar.getValue)
+            } else {
+              null
+            }
+          }
+        case hvoi: HiveCharObjectInspector =>
+          data: Any => {
+            if (data != null) {
+              UTF8String.fromString(hvoi.getPrimitiveJavaObject(data).getValue)
+            } else {
+              null
+            }
+          }
+        case x: StringObjectInspector if x.preferWritable() =>
+          data: Any => {
+            if (data != null) {
+              // Text is in UTF-8 already. No need to convert again via 
fromString. Copy bytes
+              val wObj = x.getPrimitiveWritableObject(data)
+              val result = wObj.copyBytes()
+              UTF8String.fromBytes(result, 0, result.length)
+            } else {
+              null
+            }
+          }
+        case x: StringObjectInspector =>
+          data: Any => {
+            if (data != null) {
+              UTF8String.fromString(x.getPrimitiveJavaObject(data))
+            } else {
+              null
+            }
+          }
+        case x: IntObjectInspector if x.preferWritable() =>
+          data: Any => {
+            if (data != null) x.get(data) else null
+          }
+        case x: BooleanObjectInspector if x.preferWritable() =>
+          data: Any => {
+            if (data != null) x.get(data) else null
+          }
+        case x: FloatObjectInspector if x.preferWritable() =>
+          data: Any => {
+            if (data != null) x.get(data) else null
+          }
+        case x: DoubleObjectInspector if x.preferWritable() =>
+          data: Any => {
+            if (data != null) x.get(data) else null
+          }
+        case x: LongObjectInspector if x.preferWritable() =>
+          data: Any => {
+            if (data != null) x.get(data) else null
+          }
+        case x: ShortObjectInspector if x.preferWritable() =>
+          data: Any => {
+            if (data != null) x.get(data) else null
+          }
+        case x: ByteObjectInspector if x.preferWritable() =>
+          data: Any => {
+            if (data != null) x.get(data) else null
+          }
+        case x: HiveDecimalObjectInspector =>
+          data: Any => {
+            if (data != null) {
+              HiveShim.toCatalystDecimal(x, data)
+            } else {
+              null
+            }
+          }
+        case x: BinaryObjectInspector if x.preferWritable() =>
+          data: Any => {
+            if (data != null) {
+              // BytesWritable.copyBytes() only available since Hadoop2
+              // In order to keep backward-compatible, we have to copy the
+              // bytes with old apis
+              val bw = x.getPrimitiveWritableObject(data)
+              val result = new Array[Byte](bw.getLength())
+              System.arraycopy(bw.getBytes(), 0, result, 0, bw.getLength())
+              result
+            } else {
+              null
+            }
+          }
+        case x: DateObjectInspector if x.preferWritable() =>
+          data: Any => {
+            if (data != null) {
+              
DateTimeUtils.fromJavaDate(x.getPrimitiveWritableObject(data).get())
+            } else {
+              null
+            }
+          }
+        case x: DateObjectInspector =>
+          data: Any => {
+            if (data != null) {
+              DateTimeUtils.fromJavaDate(x.getPrimitiveJavaObject(data))
+            } else {
+              null
+            }
+          }
+        case x: TimestampObjectInspector if x.preferWritable() =>
+          data: Any => {
+            if (data != null) {
+              val t = x.getPrimitiveWritableObject(data)
+              t.getSeconds * 1000000L + t.getNanos / 1000L
+            } else {
+              null
+            }
+          }
+        case ti: TimestampObjectInspector =>
+          data: Any => {
+            if (data != null) {
+              DateTimeUtils.fromJavaTimestamp(ti.getPrimitiveJavaObject(data))
+            } else {
+              null
+            }
+          }
+        case _ =>
+          data: Any => {
+            if (data != null) {
+              pi.getPrimitiveJavaObject(data)
+            } else {
+              null
+            }
+          }
+      }
+      case li: ListObjectInspector =>
+        val unwrapper = unwrapperFor(li.getListElementObjectInspector)
+        data: Any => {
+          if (data != null) {
+            Option(li.getList(data))
+              .map { l =>
+                val values = l.asScala.map(unwrapper).toArray
+                new GenericArrayData(values)
+              }
+              .orNull
+          } else {
+            null
+          }
+        }
+      case mi: MapObjectInspector =>
+        val keyUnwrapper = unwrapperFor(mi.getMapKeyObjectInspector)
+        val valueUnwrapper = unwrapperFor(mi.getMapValueObjectInspector)
+        data: Any => {
+          if (data != null) {
+            val map = mi.getMap(data)
+            if (map == null) {
+              null
+            } else {
+              val keyValues = map.asScala.toSeq
+              val keys = keyValues.map(kv => keyUnwrapper(kv._1)).toArray
+              val values = keyValues.map(kv => valueUnwrapper(kv._2)).toArray
+              ArrayBasedMapData(keys, values)
+            }
+          } else {
+            null
+          }
+        }
+      // currently, hive doesn't provide the ConstantStructObjectInspector
+      case si: StructObjectInspector =>
+        val fields = si.getAllStructFieldRefs.asScala
+        val fieldsToUnwrap = fields.zip(
+          fields.map(_.getFieldObjectInspector).map(unwrapperFor))
+        data: Any => {
+          if (data != null) {
+            InternalRow.fromSeq(fieldsToUnwrap.map { case (field, unwrapper) =>
+              unwrapper(si.getStructFieldData(data, field))
+            })
+          } else {
+            null
+          }
+        }
+    }
+
+  /**
+   * Builds unwrappers ahead of time according to object inspector
+   * types to avoid pattern matching and branching costs per row.
+   *
+   * @param field The HiveStructField to create an unwrapper for.
+   * @return A function that performs in-place updating of a MutableRow.
+   *         Use the overloaded ObjectInspector version for assignments.
    */
   def unwrapperFor(field: HiveStructField): (Any, MutableRow, Int) => Unit =
     field.getFieldObjectInspector match {
@@ -499,7 +644,8 @@ private[hive] trait HiveInspectors {
       case oi: DoubleObjectInspector =>
         (value: Any, row: MutableRow, ordinal: Int) => row.setDouble(ordinal, 
oi.get(value))
       case oi =>
-        (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = 
unwrap(value, oi)
+        val unwrapper = unwrapperFor(oi)
+        (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = 
unwrapper(value)
     }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/0a9c0275/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
index d044811..e49a235 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala
@@ -401,7 +401,8 @@ private[hive] object HadoopTableReader extends 
HiveInspectors with Logging {
           (value: Any, row: MutableRow, ordinal: Int) =>
             row.update(ordinal, oi.getPrimitiveJavaObject(value))
         case oi =>
-          (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = 
unwrap(value, oi)
+          val unwrapper = unwrapperFor(oi)
+          (value: Any, row: MutableRow, ordinal: Int) => row(ordinal) = 
unwrapper(value)
       }
     }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/0a9c0275/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
index 9e25e1d..84990d3 100644
--- 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
@@ -127,6 +127,9 @@ case class ScriptTransformation(
         }
         val mutableRow = new SpecificMutableRow(output.map(_.dataType))
 
+        @transient
+        lazy val unwrappers = 
outputSoi.getAllStructFieldRefs.asScala.map(unwrapperFor)
+
         private def checkFailureAndPropagate(cause: Throwable = null): Unit = {
           if (writerThread.exception.isDefined) {
             throw writerThread.exception.get
@@ -215,13 +218,12 @@ case class ScriptTransformation(
             val raw = outputSerde.deserialize(scriptOutputWritable)
             scriptOutputWritable = null
             val dataList = outputSoi.getStructFieldsDataAsList(raw)
-            val fieldList = outputSoi.getAllStructFieldRefs()
             var i = 0
             while (i < dataList.size()) {
               if (dataList.get(i) == null) {
                 mutableRow.setNullAt(i)
               } else {
-                mutableRow(i) = unwrap(dataList.get(i), 
fieldList.get(i).getFieldObjectInspector)
+                unwrappers(i)(dataList.get(i), mutableRow, i)
               }
               i += 1
             }

http://git-wip-us.apache.org/repos/asf/spark/blob/0a9c0275/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index c536756..9347aeb 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -71,8 +71,8 @@ private[hive] case class HiveSimpleUDF(
   override lazy val dataType = javaClassToDataType(method.getReturnType)
 
   @transient
-  lazy val returnInspector = 
ObjectInspectorFactory.getReflectionObjectInspector(
-    method.getGenericReturnType(), ObjectInspectorOptions.JAVA)
+  lazy val unwrapper = 
unwrapperFor(ObjectInspectorFactory.getReflectionObjectInspector(
+    method.getGenericReturnType(), ObjectInspectorOptions.JAVA))
 
   @transient
   private lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length)
@@ -87,7 +87,7 @@ private[hive] case class HiveSimpleUDF(
       method,
       function,
       conversionHelper.convertIfNecessary(inputs : _*): _*)
-    unwrap(ret, returnInspector)
+    unwrapper(ret)
   }
 
   override def toString: String = {
@@ -134,6 +134,9 @@ private[hive] case class HiveGenericUDF(
   }
 
   @transient
+  private lazy val unwrapper = unwrapperFor(returnInspector)
+
+  @transient
   private lazy val isUDFDeterministic = {
     val udfType = function.getClass.getAnnotation(classOf[HiveUDFType])
     udfType != null && udfType.deterministic()
@@ -156,7 +159,7 @@ private[hive] case class HiveGenericUDF(
         .set(() => children(idx).eval(input))
       i += 1
     }
-    unwrap(function.evaluate(deferredObjects), returnInspector)
+    unwrapper(function.evaluate(deferredObjects))
   }
 
   override def prettyName: String = name
@@ -210,6 +213,9 @@ private[hive] case class HiveGenericUDTF(
   @transient
   private lazy val inputDataTypes: Array[DataType] = 
children.map(_.dataType).toArray
 
+  @transient
+  private lazy val unwrapper = unwrapperFor(outputInspector)
+
   override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
     outputInspector // Make sure initialized.
 
@@ -226,7 +232,7 @@ private[hive] case class HiveGenericUDTF(
       // We need to clone the input here because implementations of
       // GenericUDTF reuse the same object. Luckily they are always an array, 
so
       // it is easy to clone.
-      collected += unwrap(input, outputInspector).asInstanceOf[InternalRow]
+      collected += unwrapper(input).asInstanceOf[InternalRow]
     }
 
     def collectRows(): Seq[InternalRow] = {
@@ -293,9 +299,12 @@ private[hive] case class HiveUDAFFunction(
   private lazy val returnInspector = functionAndInspector._2
 
   @transient
+  private lazy val unwrapper = unwrapperFor(returnInspector)
+
+  @transient
   private[this] var buffer: GenericUDAFEvaluator.AggregationBuffer = _
 
-  override def eval(input: InternalRow): Any = 
unwrap(function.evaluate(buffer), returnInspector)
+  override def eval(input: InternalRow): Any = 
unwrapper(function.evaluate(buffer))
 
   @transient
   private lazy val inputProjection = new InterpretedProjection(children)

http://git-wip-us.apache.org/repos/asf/spark/blob/0a9c0275/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
index 3b867bb..bc51bcb 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala
@@ -35,6 +35,12 @@ import org.apache.spark.sql.types._
 import org.apache.spark.sql.Row
 
 class HiveInspectorSuite extends SparkFunSuite with HiveInspectors {
+
+  def unwrap(data: Any, oi: ObjectInspector): Any = {
+    val unwrapper = unwrapperFor(oi)
+    unwrapper(data)
+  }
+
   test("Test wrap SettableStructObjectInspector") {
     val udaf = new UDAFPercentile.PercentileLongEvaluator()
     udaf.init()


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to