Repository: incubator-hivemall Updated Branches: refs/heads/master f7fc3041f -> cb63532aa
Close #62: [HIVEMALL-89][SQL] Support to_from/from_csv in HivemallOps Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/cb63532a Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/cb63532a Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/cb63532a Branch: refs/heads/master Commit: cb63532aa117b22092ced6b116ed2e4047cae447 Parents: f7fc304 Author: Takeshi Yamamuro <[email protected]> Authored: Fri Mar 17 00:16:07 2017 +0900 Committer: Takeshi Yamamuro <[email protected]> Committed: Fri Mar 17 00:16:07 2017 +0900 ---------------------------------------------------------------------- docs/gitbook/spark/misc/functions.md | 71 ++++++++- .../datasources/csv/csvExpressions.scala | 150 +++++++++++++++++++ .../org/apache/spark/sql/hive/HivemallOps.scala | 48 ++++++ .../spark/sql/hive/HivemallOpsSuite.scala | 19 +++ 4 files changed, 287 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/cb63532a/docs/gitbook/spark/misc/functions.md ---------------------------------------------------------------------- diff --git a/docs/gitbook/spark/misc/functions.md b/docs/gitbook/spark/misc/functions.md index 23763dd..fdc2292 100644 --- a/docs/gitbook/spark/misc/functions.md +++ b/docs/gitbook/spark/misc/functions.md @@ -17,9 +17,12 @@ under the License. --> +flatten +================ + `df.flatten()` flattens a nested schema of `df` into a flat one. -# Usage +## Usage ```scala scala> val df = Seq((0, (1, (3.0, "a")), (5, 0.9))).toDF() @@ -45,3 +48,69 @@ root |-- _3$_2: double (nullable = true) ``` +from_csv +================ + +This function parses a column containing a CSV string into a `StructType` +with the specified schema. 
+ +## Usage + +```scala +scala> val df = Seq("1, abc, 0.8").toDF() + +scala> df.printSchema +root + |-- value: string (nullable = true) + +scala> val schema = new StructType().add("a", IntegerType).add("b", StringType).add("c", DoubleType) + +scala> df.select(from_csv($"value", schema)).printSchema +root + |-- csvtostruct(value): struct (nullable = true) + | |-- a: integer (nullable = true) + | |-- b: string (nullable = true) + | |-- c: double (nullable = true) + +scala> df.select(from_csv($"value", schema)).show ++------------------+ +|csvtostruct(value)| ++------------------+ +| [1, abc,0.8]| ++------------------+ +``` + +to_csv +================ + +This function converts a column containing a `StructType` into a CSV string +with the specified schema. + +## Usage + +```scala +scala> val df = Seq((1, "a", (0, 3.9, "abc")), (8, "c", (2, 0.4, "def"))).toDF() + +scala> df.printSchema +root + |-- _1: integer (nullable = false) + |-- _2: string (nullable = true) + |-- _3: struct (nullable = true) + | |-- _1: integer (nullable = false) + | |-- _2: double (nullable = false) + | |-- _3: string (nullable = true) + +scala> df.select(to_csv($"_3")) + +scala> df.select(to_csv($"_3")).printSchema +root + |-- structtocsv(_3): string (nullable = true) + +scala> df.select(to_csv($"_3")).show ++---------------+ +|structtocsv(_3)| ++---------------+ +| 0,3.9,abc| +| 2,0.4,def| ++---------------+ +``` http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/cb63532a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/execution/datasources/csv/csvExpressions.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/execution/datasources/csv/csvExpressions.scala b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/execution/datasources/csv/csvExpressions.scala new file mode 100644 index 0000000..abc4c87 --- /dev/null +++ 
b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/execution/datasources/csv/csvExpressions.scala @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.execution.datasources.csv + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * Converts a csv input string to a [[StructType]] with the specified schema. 
+ * + * TODO: Move this class into org.apache.spark.sql.catalyst.expressions in Spark-v2.2+ + */ +case class CsvToStruct(schema: StructType, options: Map[String, String], child: Expression) + extends UnaryExpression with CodegenFallback with ExpectsInputTypes { + override def nullable: Boolean = true + + @transient private val csvOptions = new CSVOptions(options) + @transient private val csvReader = new CsvReader(csvOptions) + @transient private val csvParser = CSVRelation.csvParser(schema, schema.fieldNames, csvOptions) + + private def parse(s: String): InternalRow = { + csvParser(csvReader.parseLine(s), 0).orNull + } + + override def dataType: DataType = schema + + override def nullSafeEval(csv: Any): Any = { + try parse(csv.toString) catch { case _: RuntimeException => null } + } + + override def inputTypes: Seq[AbstractDataType] = StringType :: Nil +} + +/** + * Converts a [[StructType]] to a csv output string. + */ +case class StructToCsv( + options: Map[String, String], + child: Expression) + extends UnaryExpression with CodegenFallback with ExpectsInputTypes { + override def nullable: Boolean = true + + @transient + lazy val params = new CSVOptions(options) + + @transient + lazy val dataSchema = child.dataType.asInstanceOf[StructType] + + @transient + lazy val writer = new LineCsvWriter(params, dataSchema.fieldNames.toSeq) + + override def dataType: DataType = StringType + + // A `ValueConverter` is responsible for converting a value of an `InternalRow` to `String`. + // When the value is null, this converter should not be called. 
+ private type ValueConverter = (InternalRow, Int) => String + + // `ValueConverter`s for all values in the fields of the schema + private lazy val valueConverters: Array[ValueConverter] = + dataSchema.map(_.dataType).map(makeConverter).toArray + + private def verifySchema(schema: StructType): Unit = { + def verifyType(dataType: DataType): Unit = dataType match { + case ByteType | ShortType | IntegerType | LongType | FloatType | + DoubleType | BooleanType | _: DecimalType | TimestampType | + DateType | StringType => + + case udt: UserDefinedType[_] => verifyType(udt.sqlType) + + case _ => + throw new UnsupportedOperationException( + s"CSV data source does not support ${dataType.simpleString} data type.") + } + + schema.foreach(field => verifyType(field.dataType)) + } + + override def checkInputDataTypes(): TypeCheckResult = { + if (StructType.acceptsType(child.dataType)) { + try { + verifySchema(child.dataType.asInstanceOf[StructType]) + TypeCheckResult.TypeCheckSuccess + } catch { + case e: UnsupportedOperationException => + TypeCheckResult.TypeCheckFailure(e.getMessage) + } + } else { + TypeCheckResult.TypeCheckFailure( + s"$prettyName requires that the expression is a struct expression.") + } + } + + private def rowToString(row: InternalRow): Seq[String] = { + var i = 0 + val values = new Array[String](row.numFields) + while (i < row.numFields) { + if (!row.isNullAt(i)) { + values(i) = valueConverters(i).apply(row, i) + } else { + values(i) = params.nullValue + } + i += 1 + } + values + } + + private def makeConverter(dataType: DataType): ValueConverter = dataType match { + case DateType => + (row: InternalRow, ordinal: Int) => + params.dateFormat.format(DateTimeUtils.toJavaDate(row.getInt(ordinal))) + + case TimestampType => + (row: InternalRow, ordinal: Int) => + params.timestampFormat.format(DateTimeUtils.toJavaTimestamp(row.getLong(ordinal))) + + case udt: UserDefinedType[_] => makeConverter(udt.sqlType) + + case dt: DataType => + (row: InternalRow, ordinal: 
Int) => + row.get(ordinal, dt).toString + } + + override def nullSafeEval(row: Any): Any = { + writer.writeRow(rowToString(row.asInstanceOf[InternalRow]), false) + UTF8String.fromString(writer.flush()) + } + + override def inputTypes: Seq[AbstractDataType] = StructType :: Nil +} http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/cb63532a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala index d7fa202..83129a7 100644 --- a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala +++ b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical.{Generate, JoinTopK, LogicalPlan} import org.apache.spark.sql.execution.UserProvidedPlanner +import org.apache.spark.sql.execution.datasources.csv.{CsvToStruct, StructToCsv} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -1483,6 +1484,53 @@ object HivemallOps { }.as("rowid") /** + * Parses a column containing a CSV string into a [[StructType]] with the specified schema. + * Returns `null`, in the case of an unparseable string. + * @group misc + * + * @param e a string column containing CSV data. + * @param schema the schema to use when parsing the csv string + * @param options options to control how the csv is parsed. accepts the same options as the + * csv data source. 
+ */ + def from_csv(e: Column, schema: StructType, options: Map[String, String]): Column = withExpr { + CsvToStruct(schema, options, e.expr) + } + + /** + * Parses a column containing a CSV string into a [[StructType]] with the specified schema. + * Returns `null`, in the case of an unparseable string. + * @group misc + * + * @param e a string column containing CSV data. + * @param schema the schema to use when parsing the csv string + */ + def from_csv(e: Column, schema: StructType): Column = + from_csv(e, schema, Map.empty[String, String]) + + /** + * Converts a column containing a [[StructType]] into a CSV string with the specified schema. + * Throws an exception, in the case of an unsupported type. + * @group misc + * + * @param e a struct column. + * @param options options to control how the struct column is converted into a csv string. + * accepts the same options as the csv data source. + */ + def to_csv(e: Column, options: Map[String, String]): Column = withExpr { + StructToCsv(options, e.expr) + } + + /** + * Converts a column containing a [[StructType]] into a CSV string with the specified schema. + * Throws an exception, in the case of an unsupported type. + * @group misc + * + * @param e a struct column. + */ + def to_csv(e: Column): Column = to_csv(e, Map.empty[String, String]) + + /** + * A convenient function to wrap an expression and produce a Column. 
*/ @inline private def withExpr(expr: Expression): Column = Column(expr) http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/cb63532a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala index 74b2093..1547227 100644 --- a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala +++ b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala @@ -487,6 +487,25 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest { assert(errMsg2.getMessage.startsWith("Separator cannot be more than one character:")) } + test("misc - from_csv") { + import hiveContext.implicits._ + val df = Seq("""1,abc""").toDF() + val schema = new StructType().add("a", IntegerType).add("b", StringType) + checkAnswer( + df.select(from_csv($"value", schema)), + Row(Row(1, "abc")) :: Nil) + } + + test("misc - to_csv") { + import hiveContext.implicits._ + val df = Seq((1, "a", (0, 3.9, "abc")), (8, "c", (2, 0.4, "def"))).toDF() + checkAnswer( + df.select(to_csv($"_3")), + Row("0,3.9,abc") :: + Row("2,0.4,def") :: + Nil) + } + /** * This test fails because; *
