Repository: spark
Updated Branches:
  refs/heads/branch-1.5 11d231159 -> f957c59b3


[SPARK-7119] [SQL] Give script a default serde with the user specific types

This addresses an issue where an incompatible-type exception is thrown
when running the following query:
`from (from src select transform(key, value) using 'cat' as (thing1 int, thing2 string)) t select thing1 + 2;`

15/04/24 00:58:55 ERROR CliDriver: org.apache.spark.SparkException: Job aborted 
due to stage failure: Task 0 in stage 0.0 failed 1 times, most recent failure: 
Lost task 0.0 in stage 0.0 (TID 0, localhost): java.lang.ClassCastException: 
org.apache.spark.sql.types.UTF8String cannot be cast to java.lang.Integer
        at scala.runtime.BoxesRunTime.unboxToInt(BoxesRunTime.java:106)
        at scala.math.Numeric$IntIsIntegral$.plus(Numeric.scala:57)
        at 
org.apache.spark.sql.catalyst.expressions.Add.eval(arithmetic.scala:127)
        at 
org.apache.spark.sql.catalyst.expressions.Alias.eval(namedExpressions.scala:118)
        at 
org.apache.spark.sql.catalyst.expressions.InterpretedMutableProjection.apply(Projection.scala:68)
        at 
org.apache.spark.sql.catalyst.expressions.InterpretedMutableProjection.apply(Projection.scala:52)
        at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
        at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
        at scala.collection.Iterator$class.foreach(Iterator.scala:727)
        at scala.collection.AbstractIterator.foreach(Iterator.scala:1157)
        at 
scala.collection.generic.Growable$class.$plus$plus$eq(Growable.scala:48)
        at 
scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:103)
        at 
scala.collection.mutable.ArrayBuffer.$plus$plus$eq(ArrayBuffer.scala:47)
        at scala.collection.TraversableOnce$class.to(TraversableOnce.scala:273)
        at scala.collection.AbstractIterator.to(Iterator.scala:1157)
        at 
scala.collection.TraversableOnce$class.toBuffer(TraversableOnce.scala:265)
        at scala.collection.AbstractIterator.toBuffer(Iterator.scala:1157)
        at 
scala.collection.TraversableOnce$class.toArray(TraversableOnce.scala:252)
        at scala.collection.AbstractIterator.toArray(Iterator.scala:1157)
        at org.apache.spark.rdd.RDD$$anonfun$17.apply(RDD.scala:819)
        at org.apache.spark.rdd.RDD$$anonfun$17.apply(RDD.scala:819)
        at 
org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1618)
        at 
org.apache.spark.SparkContext$$anonfun$runJob$5.apply(SparkContext.scala:1618)
        at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:63)
        at org.apache.spark.scheduler.Task.run(Task.scala:64)
        at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:209)
        at 
java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1110)
        at 
java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:603)
        at java.lang.Thread.run(Thread.java:722)

chenghao-intel marmbrus

Author: zhichao.li <[email protected]>

Closes #6638 from zhichao-li/transDataType2 and squashes the following commits:

a36cc7c [zhichao.li] style
b9252a8 [zhichao.li] delete cacheRow
f6968a4 [zhichao.li] give script a default serde

(cherry picked from commit 6f8f0e265a29e89bd5192a8d5217cba19f0875da)
Signed-off-by: Michael Armbrust <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f957c59b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f957c59b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f957c59b

Branch: refs/heads/branch-1.5
Commit: f957c59b3f7f8851103bb1e36d053dc1402ebb0c
Parents: 11d2311
Author: zhichao.li <[email protected]>
Authored: Tue Aug 4 18:26:05 2015 -0700
Committer: Michael Armbrust <[email protected]>
Committed: Tue Aug 4 18:26:18 2015 -0700

----------------------------------------------------------------------
 .../org/apache/spark/sql/hive/HiveQl.scala      |  3 +-
 .../hive/execution/ScriptTransformation.scala   | 96 ++++++++------------
 .../sql/hive/execution/SQLQuerySuite.scala      | 10 ++
 3 files changed, 49 insertions(+), 60 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f957c59b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index e2fdfc6..f43e403 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -21,6 +21,7 @@ import java.sql.Date
 import java.util.Locale
 
 import org.apache.hadoop.hive.conf.HiveConf
+import org.apache.hadoop.hive.conf.HiveConf.ConfVars
 import org.apache.hadoop.hive.serde.serdeConstants
 import org.apache.hadoop.hive.ql.{ErrorMsg, Context}
 import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, FunctionInfo}
@@ -907,7 +908,7 @@ 
https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
                 }
                 (Nil, 
Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), serdeProps)
 
-              case Nil => (Nil, None, Nil)
+              case Nil => (Nil, 
Option(hiveConf().getVar(ConfVars.HIVESCRIPTSERDE)), Nil)
             }
 
             val (inRowFormat, inSerdeClass, inSerdeProps) = 
matchSerDe(inputSerdeClause)

http://git-wip-us.apache.org/repos/asf/spark/blob/f957c59b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
index fbb8640..97e4ea2 100644
--- 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
@@ -27,11 +27,11 @@ import scala.util.control.NonFatal
 import org.apache.hadoop.hive.serde.serdeConstants
 import org.apache.hadoop.hive.serde2.AbstractSerDe
 import org.apache.hadoop.hive.serde2.objectinspector._
+import org.apache.hadoop.io.Writable
 
 import org.apache.spark.{TaskContext, Logging}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.CatalystTypeConverters
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema
 import org.apache.spark.sql.execution._
@@ -106,9 +106,15 @@ case class ScriptTransformation(
 
       val reader = new BufferedReader(new InputStreamReader(inputStream))
       val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] 
with HiveInspectors {
-        var cacheRow: InternalRow = null
         var curLine: String = null
-        var eof: Boolean = false
+        val scriptOutputStream = new DataInputStream(inputStream)
+        var scriptOutputWritable: Writable = null
+        val reusedWritableObject: Writable = if (null != outputSerde) {
+          outputSerde.getSerializedClass().newInstance
+        } else {
+          null
+        }
+        val mutableRow = new SpecificMutableRow(output.map(_.dataType))
 
         override def hasNext: Boolean = {
           if (outputSerde == null) {
@@ -125,45 +131,20 @@ case class ScriptTransformation(
             } else {
               true
             }
-          } else {
-            if (eof) {
-              if (writerThread.exception.isDefined) {
-                throw writerThread.exception.get
-              }
-              false
-            } else {
+          } else if (scriptOutputWritable == null) {
+            scriptOutputWritable = reusedWritableObject
+            try {
+              scriptOutputWritable.readFields(scriptOutputStream)
               true
+            } catch {
+              case _: EOFException =>
+                if (writerThread.exception.isDefined) {
+                  throw writerThread.exception.get
+                }
+                false
             }
-          }
-        }
-
-        def deserialize(): InternalRow = {
-          if (cacheRow != null) return cacheRow
-
-          val mutableRow = new SpecificMutableRow(output.map(_.dataType))
-          try {
-            val dataInputStream = new DataInputStream(inputStream)
-            val writable = outputSerde.getSerializedClass().newInstance
-            writable.readFields(dataInputStream)
-
-            val raw = outputSerde.deserialize(writable)
-            val dataList = outputSoi.getStructFieldsDataAsList(raw)
-            val fieldList = outputSoi.getAllStructFieldRefs()
-
-            var i = 0
-            dataList.foreach( element => {
-              if (element == null) {
-                mutableRow.setNullAt(i)
-              } else {
-                mutableRow(i) = unwrap(element, 
fieldList(i).getFieldObjectInspector)
-              }
-              i += 1
-            })
-            mutableRow
-          } catch {
-            case e: EOFException =>
-              eof = true
-              null
+          } else {
+            true
           }
         }
 
@@ -171,7 +152,6 @@ case class ScriptTransformation(
           if (!hasNext) {
             throw new NoSuchElementException
           }
-
           if (outputSerde == null) {
             val prevLine = curLine
             curLine = reader.readLine()
@@ -185,12 +165,20 @@ case class ScriptTransformation(
                   .map(CatalystTypeConverters.convertToCatalyst))
             }
           } else {
-            val ret = deserialize()
-            if (!eof) {
-              cacheRow = null
-              cacheRow = deserialize()
+            val raw = outputSerde.deserialize(scriptOutputWritable)
+            scriptOutputWritable = null
+            val dataList = outputSoi.getStructFieldsDataAsList(raw)
+            val fieldList = outputSoi.getAllStructFieldRefs()
+            var i = 0
+            while (i < dataList.size()) {
+              if (dataList(i) == null) {
+                mutableRow.setNullAt(i)
+              } else {
+                mutableRow(i) = unwrap(dataList(i), 
fieldList(i).getFieldObjectInspector)
+              }
+              i += 1
             }
-            ret
+            mutableRow
           }
         }
       }
@@ -320,18 +308,8 @@ case class HiveScriptIOSchema (
   }
 
   private def parseAttrs(attrs: Seq[Expression]): (Seq[String], Seq[DataType]) 
= {
-    val columns = attrs.map {
-      case aref: AttributeReference => aref.name
-      case e: NamedExpression => e.name
-      case _ => null
-    }
-
-    val columnTypes = attrs.map {
-      case aref: AttributeReference => aref.dataType
-      case e: NamedExpression => e.dataType
-      case _ => null
-    }
-
+    val columns = attrs.zipWithIndex.map(e => s"${e._1.prettyName}_${e._2}")
+    val columnTypes = attrs.map(_.dataType)
     (columns, columnTypes)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/f957c59b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index fb41451..ff9a369 100644
--- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -751,6 +751,16 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils {
         .queryExecution.toRdd.count())
   }
 
+  test("test script transform data type") {
+    val data = (1 to 5).map { i => (i, i) }
+    data.toDF("key", "value").registerTempTable("test")
+    checkAnswer(
+      sql("""FROM
+          |(FROM test SELECT TRANSFORM(key, value) USING 'cat' AS (thing1 int, 
thing2 string)) t
+          |SELECT thing1 + 1
+        """.stripMargin), (2 to 6).map(i => Row(i)))
+  }
+
   test("window function: udaf with aggregate expressin") {
     val data = Seq(
       WindowData(1, "a", 5),


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to