Repository: spark
Updated Branches:
  refs/heads/master 61d4c07f4 -> 84f81e035


[SPARK-10310] [SQL] Fixes script transformation field/line delimiters

**Please attribute this PR to `Zhichao Li <zhichao.li@intel.com>`.**

This PR is based on PR #8476, authored by zhichao-li. It fixes SPARK-10310 by
adding a field delimiter SerDe property (a tab, `\t`) to the default
`LazySimpleSerDe`, and by enabling the default record reader/writer classes.

Currently, only `LazySimpleSerDe` is supported, used together with
`TextRecordReader` and `TextRecordWriter`; customizing the record reader/writer
through the `RECORDREADER`/`RECORDWRITER` clauses is not yet supported. This
should be addressed in separate PR(s).

Author: Cheng Lian <[email protected]>

Closes #8860 from liancheng/spark-10310/fix-script-trans-delimiters.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/84f81e03
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/84f81e03
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/84f81e03

Branch: refs/heads/master
Commit: 84f81e035e1dab1b42c36563041df6ba16e7b287
Parents: 61d4c07
Author: Zhichao Li <[email protected]>
Authored: Tue Sep 22 19:41:57 2015 -0700
Committer: Yin Huai <[email protected]>
Committed: Tue Sep 22 19:41:57 2015 -0700

----------------------------------------------------------------------
 .../org/apache/spark/sql/hive/HiveQl.scala      | 52 +++++++++++---
 .../hive/execution/ScriptTransformation.scala   | 75 ++++++++++++++++----
 .../resources/data/scripts/test_transform.py    |  6 ++
 .../sql/hive/execution/SQLQuerySuite.scala      | 39 ++++++++++
 .../execution/ScriptTransformationSuite.scala   |  2 +
 5 files changed, 152 insertions(+), 22 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/84f81e03/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index d5cd7e9..256440a 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -32,6 +32,7 @@ import org.apache.hadoop.hive.ql.lib.Node
 import org.apache.hadoop.hive.ql.parse._
 import org.apache.hadoop.hive.ql.plan.PlanUtils
 import org.apache.hadoop.hive.ql.session.SessionState
+import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
 
 import org.apache.spark.Logging
 import org.apache.spark.sql.AnalysisException
@@ -884,16 +885,22 @@ 
https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
                   AttributeReference("value", StringType)()), true)
             }
 
-            def matchSerDe(clause: Seq[ASTNode])
-              : (Seq[(String, String)], Option[String], Seq[(String, String)]) 
= clause match {
+            type SerDeInfo = (
+              Seq[(String, String)],  // Input row format information
+              Option[String],         // Optional input SerDe class
+              Seq[(String, String)],  // Input SerDe properties
+              Boolean                 // Whether to use default record 
reader/writer
+            )
+
+            def matchSerDe(clause: Seq[ASTNode]): SerDeInfo = clause match {
               case Token("TOK_SERDEPROPS", propsClause) :: Nil =>
                 val rowFormat = propsClause.map {
                   case Token(name, Token(value, Nil) :: Nil) => (name, value)
                 }
-                (rowFormat, None, Nil)
+                (rowFormat, None, Nil, false)
 
               case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Nil) :: 
Nil =>
-                (Nil, 
Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), Nil)
+                (Nil, 
Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), Nil, false)
 
               case Token("TOK_SERDENAME", Token(serdeClass, Nil) ::
                 Token("TOK_TABLEPROPERTIES",
@@ -903,20 +910,47 @@ 
https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
                     (BaseSemanticAnalyzer.unescapeSQLString(name),
                       BaseSemanticAnalyzer.unescapeSQLString(value))
                 }
-                (Nil, 
Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), serdeProps)
 
-              case Nil => (Nil, 
Option(hiveConf.getVar(ConfVars.HIVESCRIPTSERDE)), Nil)
+                // SPARK-10310: Special cases LazySimpleSerDe
+                // TODO Fully supports user-defined record reader/writer 
classes
+                val unescapedSerDeClass = 
BaseSemanticAnalyzer.unescapeSQLString(serdeClass)
+                val useDefaultRecordReaderWriter =
+                  unescapedSerDeClass == 
classOf[LazySimpleSerDe].getCanonicalName
+                (Nil, Some(unescapedSerDeClass), serdeProps, 
useDefaultRecordReaderWriter)
+
+              case Nil =>
+                // Uses default TextRecordReader/TextRecordWriter, sets field 
delimiter here
+                val serdeProps = Seq(serdeConstants.FIELD_DELIM -> "\t")
+                (Nil, Option(hiveConf.getVar(ConfVars.HIVESCRIPTSERDE)), 
serdeProps, true)
             }
 
-            val (inRowFormat, inSerdeClass, inSerdeProps) = 
matchSerDe(inputSerdeClause)
-            val (outRowFormat, outSerdeClass, outSerdeProps) = 
matchSerDe(outputSerdeClause)
+            val (inRowFormat, inSerdeClass, inSerdeProps, 
useDefaultRecordReader) =
+              matchSerDe(inputSerdeClause)
+
+            val (outRowFormat, outSerdeClass, outSerdeProps, 
useDefaultRecordWriter) =
+              matchSerDe(outputSerdeClause)
 
             val unescapedScript = 
BaseSemanticAnalyzer.unescapeSQLString(script)
 
+            // TODO Adds support for user-defined record reader/writer classes
+            val recordReaderClass = if (useDefaultRecordReader) {
+              Option(hiveConf.getVar(ConfVars.HIVESCRIPTRECORDREADER))
+            } else {
+              None
+            }
+
+            val recordWriterClass = if (useDefaultRecordWriter) {
+              Option(hiveConf.getVar(ConfVars.HIVESCRIPTRECORDWRITER))
+            } else {
+              None
+            }
+
             val schema = HiveScriptIOSchema(
               inRowFormat, outRowFormat,
               inSerdeClass, outSerdeClass,
-              inSerdeProps, outSerdeProps, schemaLess)
+              inSerdeProps, outSerdeProps,
+              recordReaderClass, recordWriterClass,
+              schemaLess)
 
             Some(
               logical.ScriptTransformation(

http://git-wip-us.apache.org/repos/asf/spark/blob/84f81e03/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
index 32bddba..b30117f 100644
--- 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
@@ -24,20 +24,22 @@ import javax.annotation.Nullable
 import scala.collection.JavaConverters._
 import scala.util.control.NonFatal
 
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.hive.ql.exec.{RecordReader, RecordWriter}
 import org.apache.hadoop.hive.serde.serdeConstants
 import org.apache.hadoop.hive.serde2.AbstractSerDe
 import org.apache.hadoop.hive.serde2.objectinspector._
 import org.apache.hadoop.io.Writable
 
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.hive.HiveShim._
 import org.apache.spark.sql.hive.{HiveContext, HiveInspectors}
 import org.apache.spark.sql.types.DataType
-import org.apache.spark.util.{CircularBuffer, RedirectThread, Utils}
+import org.apache.spark.util.{CircularBuffer, RedirectThread, 
SerializableConfiguration, Utils}
 import org.apache.spark.{Logging, TaskContext}
 
 /**
@@ -58,6 +60,8 @@ case class ScriptTransformation(
 
   override def otherCopyArgs: Seq[HiveContext] = sc :: Nil
 
+  private val serializedHiveConf = new SerializableConfiguration(sc.hiveconf)
+
   protected override def doExecute(): RDD[InternalRow] = {
     def processIterator(inputIterator: Iterator[InternalRow]): 
Iterator[InternalRow] = {
       val cmd = List("/bin/bash", "-c", script)
@@ -67,6 +71,7 @@ case class ScriptTransformation(
       val inputStream = proc.getInputStream
       val outputStream = proc.getOutputStream
       val errorStream = proc.getErrorStream
+      val localHiveConf = serializedHiveConf.value
 
       // In order to avoid deadlocks, we need to consume the error output of 
the child process.
       // To avoid issues caused by large error output, we use a circular 
buffer to limit the amount
@@ -96,7 +101,8 @@ case class ScriptTransformation(
         outputStream,
         proc,
         stderrBuffer,
-        TaskContext.get()
+        TaskContext.get(),
+        localHiveConf
       )
 
       // This nullability is a performance optimization in order to avoid an 
Option.foreach() call
@@ -109,6 +115,10 @@ case class ScriptTransformation(
       val outputIterator: Iterator[InternalRow] = new Iterator[InternalRow] 
with HiveInspectors {
         var curLine: String = null
         val scriptOutputStream = new DataInputStream(inputStream)
+
+        @Nullable val scriptOutputReader =
+          ioschema.recordReader(scriptOutputStream, localHiveConf).orNull
+
         var scriptOutputWritable: Writable = null
         val reusedWritableObject: Writable = if (null != outputSerde) {
           outputSerde.getSerializedClass().newInstance
@@ -134,15 +144,25 @@ case class ScriptTransformation(
             }
           } else if (scriptOutputWritable == null) {
             scriptOutputWritable = reusedWritableObject
-            try {
-              scriptOutputWritable.readFields(scriptOutputStream)
-              true
-            } catch {
-              case _: EOFException =>
-                if (writerThread.exception.isDefined) {
-                  throw writerThread.exception.get
-                }
+
+            if (scriptOutputReader != null) {
+              if (scriptOutputReader.next(scriptOutputWritable) <= 0) {
+                writerThread.exception.foreach(throw _)
                 false
+              } else {
+                true
+              }
+            } else {
+              try {
+                scriptOutputWritable.readFields(scriptOutputStream)
+                true
+              } catch {
+                case _: EOFException =>
+                  if (writerThread.exception.isDefined) {
+                    throw writerThread.exception.get
+                  }
+                  false
+              }
             }
           } else {
             true
@@ -210,7 +230,8 @@ private class ScriptTransformationWriterThread(
     outputStream: OutputStream,
     proc: Process,
     stderrBuffer: CircularBuffer,
-    taskContext: TaskContext
+    taskContext: TaskContext,
+    conf: Configuration
   ) extends Thread("Thread-ScriptTransformation-Feed") with Logging {
 
   setDaemon(true)
@@ -224,6 +245,7 @@ private class ScriptTransformationWriterThread(
     TaskContext.setTaskContext(taskContext)
 
     val dataOutputStream = new DataOutputStream(outputStream)
+    @Nullable val scriptInputWriter = ioschema.recordWriter(dataOutputStream, 
conf).orNull
 
     // We can't use Utils.tryWithSafeFinally here because we also need a 
`catch` block, so
     // let's use a variable to record whether the `finally` block was hit due 
to an exception
@@ -250,7 +272,12 @@ private class ScriptTransformationWriterThread(
         } else {
           val writable = inputSerde.serialize(
             row.asInstanceOf[GenericInternalRow].values, inputSoi)
-          prepareWritable(writable, 
ioschema.outputSerdeProps).write(dataOutputStream)
+
+          if (scriptInputWriter != null) {
+            scriptInputWriter.write(writable)
+          } else {
+            prepareWritable(writable, 
ioschema.outputSerdeProps).write(dataOutputStream)
+          }
         }
       }
       outputStream.close()
@@ -290,6 +317,8 @@ case class HiveScriptIOSchema (
     outputSerdeClass: Option[String],
     inputSerdeProps: Seq[(String, String)],
     outputSerdeProps: Seq[(String, String)],
+    recordReaderClass: Option[String],
+    recordWriterClass: Option[String],
     schemaLess: Boolean) extends ScriptInputOutputSchema with HiveInspectors {
 
   private val defaultFormat = Map(
@@ -347,4 +376,24 @@ case class HiveScriptIOSchema (
 
     serde
   }
+
+  def recordReader(
+      inputStream: InputStream,
+      conf: Configuration): Option[RecordReader] = {
+    recordReaderClass.map { klass =>
+      val instance = 
Utils.classForName(klass).newInstance().asInstanceOf[RecordReader]
+      val props = new Properties()
+      props.putAll(outputSerdeProps.toMap.asJava)
+      instance.initialize(inputStream, conf, props)
+      instance
+    }
+  }
+
+  def recordWriter(outputStream: OutputStream, conf: Configuration): 
Option[RecordWriter] = {
+    recordWriterClass.map { klass =>
+      val instance = 
Utils.classForName(klass).newInstance().asInstanceOf[RecordWriter]
+      instance.initialize(outputStream, conf)
+      instance
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/84f81e03/sql/hive/src/test/resources/data/scripts/test_transform.py
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/resources/data/scripts/test_transform.py 
b/sql/hive/src/test/resources/data/scripts/test_transform.py
new file mode 100755
index 0000000..ac6d11d
--- /dev/null
+++ b/sql/hive/src/test/resources/data/scripts/test_transform.py
@@ -0,0 +1,6 @@
+import sys
+
+delim = sys.argv[1]
+
+for row in sys.stdin:
+    print(delim.join([w + '#' for w in row[:-1].split(delim)]))

http://git-wip-us.apache.org/repos/asf/spark/blob/84f81e03/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index bb02473..71823e3 100644
--- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -1184,4 +1184,43 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils 
with TestHiveSingleton {
 
     checkAnswer(df, Row("text inside layer 2") :: Nil)
   }
+
+  test("SPARK-10310: " +
+    "script transformation using default input/output SerDe and record 
reader/writer") {
+    sqlContext
+      .range(5)
+      .selectExpr("id AS a", "id AS b")
+      .registerTempTable("test")
+
+    checkAnswer(
+      sql(
+        """FROM(
+          |  FROM test SELECT TRANSFORM(a, b)
+          |  USING 'python src/test/resources/data/scripts/test_transform.py 
"\t"'
+          |  AS (c STRING, d STRING)
+          |) t
+          |SELECT c
+        """.stripMargin),
+      (0 until 5).map(i => Row(i + "#")))
+  }
+
+  test("SPARK-10310: script transformation using LazySimpleSerDe") {
+    sqlContext
+      .range(5)
+      .selectExpr("id AS a", "id AS b")
+      .registerTempTable("test")
+
+    val df = sql(
+      """FROM test
+        |SELECT TRANSFORM(a, b)
+        |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'
+        |WITH SERDEPROPERTIES('field.delim' = '|')
+        |USING 'python src/test/resources/data/scripts/test_transform.py "|"'
+        |AS (c STRING, d STRING)
+        |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'
+        |WITH SERDEPROPERTIES('field.delim' = '|')
+      """.stripMargin)
+
+    checkAnswer(df, (0 until 5).map(i => Row(i + "#", i + "#")))
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/84f81e03/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
index cb8d0fc..7cfdb88 100644
--- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
+++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
@@ -38,6 +38,8 @@ class ScriptTransformationSuite extends SparkPlanTest with 
TestHiveSingleton {
     outputSerdeClass = None,
     inputSerdeProps = Seq.empty,
     outputSerdeProps = Seq.empty,
+    recordReaderClass = None,
+    recordWriterClass = None,
     schemaLess = false
   )
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to