Repository: incubator-hivemall
Updated Branches:
  refs/heads/master f7fc3041f -> cb63532aa


Close #62: [HIVEMALL-89][SQL] Support to_csv/from_csv in HivemallOps


Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: 
http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/cb63532a
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/cb63532a
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/cb63532a

Branch: refs/heads/master
Commit: cb63532aa117b22092ced6b116ed2e4047cae447
Parents: f7fc304
Author: Takeshi Yamamuro <[email protected]>
Authored: Fri Mar 17 00:16:07 2017 +0900
Committer: Takeshi Yamamuro <[email protected]>
Committed: Fri Mar 17 00:16:07 2017 +0900

----------------------------------------------------------------------
 docs/gitbook/spark/misc/functions.md            |  71 ++++++++-
 .../datasources/csv/csvExpressions.scala        | 150 +++++++++++++++++++
 .../org/apache/spark/sql/hive/HivemallOps.scala |  48 ++++++
 .../spark/sql/hive/HivemallOpsSuite.scala       |  19 +++
 4 files changed, 287 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/cb63532a/docs/gitbook/spark/misc/functions.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/spark/misc/functions.md 
b/docs/gitbook/spark/misc/functions.md
index 23763dd..fdc2292 100644
--- a/docs/gitbook/spark/misc/functions.md
+++ b/docs/gitbook/spark/misc/functions.md
@@ -17,9 +17,12 @@
   under the License.
 -->
 
+flatten
+================
+
 `df.flatten()` flattens a nested schema of `df` into a flat one.
 
-# Usage
+## Usage
 
 ```scala
 scala> val df = Seq((0, (1, (3.0, "a")), (5, 0.9))).toDF()
@@ -45,3 +48,69 @@ root
  |-- _3$_2: double (nullable = true)
 ```
 
+from_csv
+================
+
+This function parses a column containing a CSV string into a `StructType`
+with the specified schema.
+
+## Usage
+
+```scala
+scala> val df = Seq("1, abc, 0.8").toDF()
+
+scala> df.printSchema
+root
+ |-- value: string (nullable = true)
+
+scala> val schema = new StructType().add("a", IntegerType).add("b", 
StringType).add("c", DoubleType)
+
+scala> df.select(from_csv($"value", schema)).printSchema
+root
+ |-- csvtostruct(value): struct (nullable = true)
+ |    |-- a: integer (nullable = true)
+ |    |-- b: string (nullable = true)
+ |    |-- c: double (nullable = true)
+
+scala> df.select(from_csv($"value", schema)).show
++------------------+
+|csvtostruct(value)|
++------------------+
+|      [1, abc,0.8]|
++------------------+
+```
+
+to_csv
+================
+
+This function converts a column containing a `StructType` into a CSV string,
+using the schema of the struct column itself.
+
+## Usage
+
+```scala
+scala> val df = Seq((1, "a", (0, 3.9, "abc")), (8, "c", (2, 0.4, 
"def"))).toDF()
+
+scala> df.printSchema
+root
+ |-- _1: integer (nullable = false)
+ |-- _2: string (nullable = true)
+ |-- _3: struct (nullable = true)
+ |    |-- _1: integer (nullable = false)
+ |    |-- _2: double (nullable = false)
+ |    |-- _3: string (nullable = true)
+
+scala> df.select(to_csv($"_3"))
+
+scala> df.select(to_csv($"_3")).printSchema
+root
+ |-- structtocsv(_3): string (nullable = true)
+
+scala> df.select(to_csv($"_3")).show
++---------------+
+|structtocsv(_3)|
++---------------+
+|      0,3.9,abc|
+|      2,0.4,def|
++---------------+
+```

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/cb63532a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/execution/datasources/csv/csvExpressions.scala
----------------------------------------------------------------------
diff --git 
a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/execution/datasources/csv/csvExpressions.scala
 
b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/execution/datasources/csv/csvExpressions.scala
new file mode 100644
index 0000000..abc4c87
--- /dev/null
+++ 
b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/execution/datasources/csv/csvExpressions.scala
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.execution.datasources.csv
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, 
Expression, UnaryExpression}
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * Converts a csv input string to a [[StructType]] with the specified schema.
+ *
+ * Evaluation yields `null` when parsing the line throws a `RuntimeException` or when the
+ * parser produces no row.
+ *
+ * The CSV helpers are `@transient lazy val`s. `@transient` fields are skipped when the
+ * expression is serialized, so eager `val`s would be `null` after deserialization; every
+ * `parse` call would then fail with a `NullPointerException` that `nullSafeEval` silently
+ * turns into a `null` result. Making them lazy rebuilds them on first use.
+ *
+ * TODO: Move this class into org.apache.spark.sql.catalyst.expressions in Spark-v2.2+
+ *
+ * @param schema the schema of the struct to produce
+ * @param options options handed to `CSVOptions`
+ * @param child the expression producing the csv string
+ */
+case class CsvToStruct(schema: StructType, options: Map[String, String], child: Expression)
+  extends UnaryExpression with CodegenFallback with ExpectsInputTypes {
+  override def nullable: Boolean = true
+
+  @transient private lazy val csvOptions = new CSVOptions(options)
+  @transient private lazy val csvReader = new CsvReader(csvOptions)
+  @transient private lazy val csvParser =
+    CSVRelation.csvParser(schema, schema.fieldNames, csvOptions)
+
+  private def parse(s: String): InternalRow = {
+    // `orNull` turns a missing row into `null`
+    csvParser(csvReader.parseLine(s), 0).orNull
+  }
+
+  override def dataType: DataType = schema
+
+  override def nullSafeEval(csv: Any): Any = {
+    // Unparseable input is reported as `null` rather than failing the query
+    try parse(csv.toString) catch { case _: RuntimeException => null }
+  }
+
+  override def inputTypes: Seq[AbstractDataType] = StringType :: Nil
+}
+
+/**
+ * Converts a [[StructType]] to a csv output string.
+ *
+ * The schema is taken from the data type of `child`; the result type is `StringType`.
+ * Field values are rendered as follows:
+ *  - `DateType` and `TimestampType` go through `params.dateFormat` and
+ *    `params.timestampFormat` respectively,
+ *  - a `UserDefinedType` is rendered like its `sqlType`,
+ *  - any other type uses the `toString` of the value,
+ *  - a null field is written as `params.nullValue`.
+ *
+ * `checkInputDataTypes` fails when `child` is not a struct, or when a field type is not
+ * one of the atomic types listed in `verifySchema` (or a `UserDefinedType` over one).
+ *
+ * `nullSafeEval` writes the row into the shared `writer` and then flushes it to obtain
+ * the string, so the two calls must stay in that order.
+ *
+ * @param options options handed to `CSVOptions`
+ * @param child the expression producing the struct to convert
+ */
+case class StructToCsv(
+    options: Map[String, String],
+    child: Expression)
+  extends UnaryExpression with CodegenFallback with ExpectsInputTypes {
+  override def nullable: Boolean = true
+
+  @transient
+  lazy val params = new CSVOptions(options)
+
+  @transient
+  lazy val dataSchema = child.dataType.asInstanceOf[StructType]
+
+  @transient
+  lazy val writer = new LineCsvWriter(params, dataSchema.fieldNames.toSeq)
+
+  override def dataType: DataType = StringType
+
+  // A `ValueConverter` is responsible for converting a value of an 
`InternalRow` to `String`.
+  // When the value is null, this converter should not be called.
+  private type ValueConverter = (InternalRow, Int) => String
+
+  // `ValueConverter`s for all values in the fields of the schema
+  private lazy val valueConverters: Array[ValueConverter] =
+    dataSchema.map(_.dataType).map(makeConverter).toArray
+
+  private def verifySchema(schema: StructType): Unit = {
+    def verifyType(dataType: DataType): Unit = dataType match {
+      case ByteType | ShortType | IntegerType | LongType | FloatType |
+           DoubleType | BooleanType | _: DecimalType | TimestampType |
+           DateType | StringType =>
+
+      case udt: UserDefinedType[_] => verifyType(udt.sqlType)
+
+      case _ =>
+        throw new UnsupportedOperationException(
+          s"CSV data source does not support ${dataType.simpleString} data 
type.")
+    }
+
+    schema.foreach(field => verifyType(field.dataType))
+  }
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (StructType.acceptsType(child.dataType)) {
+      try {
+        verifySchema(child.dataType.asInstanceOf[StructType])
+        TypeCheckResult.TypeCheckSuccess
+      } catch {
+        case e: UnsupportedOperationException =>
+          TypeCheckResult.TypeCheckFailure(e.getMessage)
+      }
+    } else {
+      TypeCheckResult.TypeCheckFailure(
+        s"$prettyName requires that the expression is a struct expression.")
+    }
+  }
+
+  private def rowToString(row: InternalRow): Seq[String] = {
+    var i = 0
+    val values = new Array[String](row.numFields)
+    while (i < row.numFields) {
+      if (!row.isNullAt(i)) {
+        values(i) = valueConverters(i).apply(row, i)
+      } else {
+        values(i) = params.nullValue
+      }
+      i += 1
+    }
+    values
+  }
+
+  private def makeConverter(dataType: DataType): ValueConverter = dataType 
match {
+    case DateType =>
+      (row: InternalRow, ordinal: Int) =>
+        params.dateFormat.format(DateTimeUtils.toJavaDate(row.getInt(ordinal)))
+
+    case TimestampType =>
+      (row: InternalRow, ordinal: Int) =>
+        
params.timestampFormat.format(DateTimeUtils.toJavaTimestamp(row.getLong(ordinal)))
+
+    case udt: UserDefinedType[_] => makeConverter(udt.sqlType)
+
+    case dt: DataType =>
+      (row: InternalRow, ordinal: Int) =>
+        row.get(ordinal, dt).toString
+  }
+
+  override def nullSafeEval(row: Any): Any = {
+    writer.writeRow(rowToString(row.asInstanceOf[InternalRow]), false)
+    UTF8String.fromString(writer.flush())
+  }
+
+  override def inputTypes: Seq[AbstractDataType] = StructType :: Nil
+}

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/cb63532a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
----------------------------------------------------------------------
diff --git 
a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala 
b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
index d7fa202..83129a7 100644
--- a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
+++ b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.logical.{Generate, JoinTopK, 
LogicalPlan}
 import org.apache.spark.sql.execution.UserProvidedPlanner
+import org.apache.spark.sql.execution.datasources.csv.{CsvToStruct, 
StructToCsv}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -1483,6 +1484,53 @@ object HivemallOps {
   }.as("rowid")
 
   /**
+   * Parses a column containing a CSV string into a [[StructType]] with the specified schema.
+   * Returns `null`, in the case of an unparseable string.
+   * @group misc
+   *
+   * @param e a string column containing CSV data.
+   * @param schema the schema to use when parsing the CSV string.
+   * @param options options to control how the CSV string is parsed; accepts the same
+   *                options as the CSV data source.
+   * @return a column wrapping a `CsvToStruct` expression over `e`.
+   */
+  def from_csv(e: Column, schema: StructType, options: Map[String, String]): 
Column = withExpr {
+    CsvToStruct(schema, options, e.expr)
+  }
+
+  /**
+   * Parses a column containing a CSV string into a [[StructType]] with the specified schema,
+   * using an empty option map.
+   * Returns `null`, in the case of an unparseable string.
+   * @group misc
+   *
+   * @param e a string column containing CSV data.
+   * @param schema the schema to use when parsing the CSV string.
+   */
+  def from_csv(e: Column, schema: StructType): Column =
+    from_csv(e, schema, Map.empty[String, String])
+
+  /**
+   * Converts a column containing a [[StructType]] into a CSV string, using the schema of
+   * the struct column itself.
+   * Throws an exception, in the case of an unsupported type.
+   * @group misc
+   *
+   * @param e a struct column.
+   * @param options options to control how the struct column is converted into a CSV string;
+   *                accepts the same options as the CSV data source.
+   * @return a column wrapping a `StructToCsv` expression over `e`.
+   */
+  def to_csv(e: Column, options: Map[String, String]): Column = withExpr {
+    StructToCsv(options, e.expr)
+  }
+
+  /**
+   * Converts a column containing a [[StructType]] into a CSV string, using the schema of
+   * the struct column itself and an empty option map.
+   * Throws an exception, in the case of an unsupported type.
+   * @group misc
+   *
+   * @param e a struct column.
+   */
+  def to_csv(e: Column): Column = to_csv(e, Map.empty[String, String])
+
+  /**
    * A convenient function to wrap an expression and produce a Column.
    */
   @inline private def withExpr(expr: Expression): Column = Column(expr)

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/cb63532a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
----------------------------------------------------------------------
diff --git 
a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
 
b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
index 74b2093..1547227 100644
--- 
a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
+++ 
b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
@@ -487,6 +487,25 @@ final class HivemallOpsWithFeatureSuite extends 
HivemallFeatureQueryTest {
     assert(errMsg2.getMessage.startsWith("Separator cannot be more than one 
character:"))
   }
 
+  test("misc - from_csv") {
+    import hiveContext.implicits._
+    // One CSV line is expected to come back as a single struct of (a: Int, b: String)
+    val csvSchema = new StructType().add("a", IntegerType).add("b", StringType)
+    val input = Seq("""1,abc""").toDF()
+    val parsed = input.select(from_csv($"value", csvSchema))
+    checkAnswer(parsed, Seq(Row(Row(1, "abc"))))
+  }
+
+  test("misc - to_csv") {
+    import hiveContext.implicits._
+    // Only the nested struct column `_3` is rendered, one CSV line per input row
+    val input = Seq((1, "a", (0, 3.9, "abc")), (8, "c", (2, 0.4, "def"))).toDF()
+    val csvLines = input.select(to_csv($"_3"))
+    checkAnswer(csvLines, Seq(Row("0,3.9,abc"), Row("2,0.4,def")))
+  }
+
   /**
    * This test fails because;
    *

Reply via email to