Repository: incubator-hivemall
Updated Branches:
  refs/heads/master 210b7765b -> 33baa2408
Close #61: [HIVEMALL-88][SPARK] Support a function to flatten nested schemas

Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/33baa240
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/33baa240
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/33baa240

Branch: refs/heads/master
Commit: 33baa2408b77895a4feaaa0f60953055657275d0
Parents: 210b776
Author: Takeshi Yamamuro <yamam...@apache.org>
Authored: Thu Mar 9 16:53:00 2017 +0900
Committer: Takeshi Yamamuro <yamam...@apache.org>
Committed: Thu Mar 9 16:53:00 2017 +0900

----------------------------------------------------------------------
 .../datasources/csv/csvExpressions.scala        | 153 +++++++++++++++++++
 .../org/apache/spark/sql/hive/HivemallOps.scala |  48 ++++++
 .../spark/sql/hive/HivemallOpsSuite.scala       |  19 +++
 3 files changed, 220 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/33baa240/spark/spark-2.1/src/main/scala/org/apache/spark/sql/execution/datasources/csv/csvExpressions.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/execution/datasources/csv/csvExpressions.scala b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/execution/datasources/csv/csvExpressions.scala
new file mode 100644
index 0000000..363d432
--- /dev/null
+++ b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/execution/datasources/csv/csvExpressions.scala
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.csv
+
+import java.io.CharArrayWriter
+
+import jodd.util.CsvUtil
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, UnaryExpression}
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * Converts a csv input string to a [[StructType]] with the specified schema.
+ *
+ * TODO: Move this class into org.apache.spark.sql.catalyst.expressions in Spark-v2.2+
+ */
+case class CsvToStruct(schema: StructType, options: Map[String, String], child: Expression)
+  extends UnaryExpression with CodegenFallback with ExpectsInputTypes {
+  override def nullable: Boolean = true
+
+  @transient private val csvOptions = new CSVOptions(options)
+  @transient private val csvReader = new CsvReader(csvOptions)
+  @transient private val csvParser = CSVRelation.csvParser(schema, schema.fieldNames, csvOptions)
+
+  private def parse(s: String): InternalRow = {
+    csvParser(csvReader.parseLine(s), 0).orNull
+  }
+
+  override def dataType: DataType = schema
+
+  override def nullSafeEval(csv: Any): Any = {
+    try parse(csv.toString) catch { case _: RuntimeException => null }
+  }
+
+  override def inputTypes: Seq[AbstractDataType] = StringType :: Nil
+}
+
+/**
+ * Converts a [[StructType]] to a csv output string.
+ */
+case class StructToCsv(
+    options: Map[String, String],
+    child: Expression)
+  extends UnaryExpression with CodegenFallback with ExpectsInputTypes {
+  override def nullable: Boolean = true
+
+  @transient
+  lazy val params = new CSVOptions(options)
+
+  @transient
+  lazy val dataSchema = child.dataType.asInstanceOf[StructType]
+
+  @transient
+  lazy val writer = new LineCsvWriter(params, dataSchema.fieldNames.toSeq)
+
+  override def dataType: DataType = StringType
+
+  // A `ValueConverter` is responsible for converting a value of an `InternalRow` to `String`.
+  // When the value is null, this converter should not be called.
+  private type ValueConverter = (InternalRow, Int) => String
+
+  // `ValueConverter`s for all values in the fields of the schema
+  private lazy val valueConverters: Array[ValueConverter] =
+    dataSchema.map(_.dataType).map(makeConverter).toArray
+
+  private def verifySchema(schema: StructType): Unit = {
+    def verifyType(dataType: DataType): Unit = dataType match {
+      case ByteType | ShortType | IntegerType | LongType | FloatType |
+           DoubleType | BooleanType | _: DecimalType | TimestampType |
+           DateType | StringType =>
+
+      case udt: UserDefinedType[_] => verifyType(udt.sqlType)
+
+      case _ =>
+        throw new UnsupportedOperationException(
+          s"CSV data source does not support ${dataType.simpleString} data type.")
+    }
+
+    schema.foreach(field => verifyType(field.dataType))
+  }
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (StructType.acceptsType(child.dataType)) {
+      try {
+        verifySchema(child.dataType.asInstanceOf[StructType])
+        TypeCheckResult.TypeCheckSuccess
+      } catch {
+        case e: UnsupportedOperationException =>
+          TypeCheckResult.TypeCheckFailure(e.getMessage)
+      }
+    } else {
+      TypeCheckResult.TypeCheckFailure(
+        s"$prettyName requires that the expression is a struct expression.")
+    }
+  }
+
+  private def rowToString(row: InternalRow): Seq[String] = {
+    var i = 0
+    val values = new Array[String](row.numFields)
+    while (i < row.numFields) {
+      if (!row.isNullAt(i)) {
+        values(i) = valueConverters(i).apply(row, i)
+      } else {
+        values(i) = params.nullValue
+      }
+      i += 1
+    }
+    values
+  }
+
+  private def makeConverter(dataType: DataType): ValueConverter = dataType match {
+    case DateType =>
+      (row: InternalRow, ordinal: Int) =>
+        params.dateFormat.format(DateTimeUtils.toJavaDate(row.getInt(ordinal)))
+
+    case TimestampType =>
+      (row: InternalRow, ordinal: Int) =>
+        params.timestampFormat.format(DateTimeUtils.toJavaTimestamp(row.getLong(ordinal)))
+
+    case udt: UserDefinedType[_] => makeConverter(udt.sqlType)
+
+    case dt: DataType =>
+      (row: InternalRow, ordinal: Int) =>
+        row.get(ordinal, dt).toString
+  }
+
+  override def nullSafeEval(row: Any): Any = {
+    writer.writeRow(rowToString(row.asInstanceOf[InternalRow]), false)
+    UTF8String.fromString(writer.flush())
+  }
+
+  override def inputTypes: Seq[AbstractDataType] = StructType :: Nil
+}
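
The two expressions above are not meant to be constructed directly; they are wrapped as Columns by the from_csv/to_csv helpers added to HivemallOps in the next hunk. A minimal usage sketch, assuming a spark-shell session (a SparkSession named `spark`) with this spark-2.1 module on the classpath:

  import org.apache.spark.sql.hive.HivemallOps.{from_csv, to_csv}
  import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
  import spark.implicits._

  // from_csv wraps CsvToStruct: parse a CSV string column against an explicit schema.
  // Per nullSafeEval above, an unparseable string yields null rather than an error.
  val schema = new StructType().add("a", IntegerType).add("b", StringType)
  val parsed = Seq("1,abc").toDF("value").select(from_csv($"value", schema).as("s"))

  // to_csv wraps StructToCsv: render the struct back into one CSV line per row.
  parsed.select(to_csv($"s")).show()  // prints 1,abc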


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/33baa240/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
index 6883ac1..f583cca 100644
--- a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
+++ b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.logical.{Generate, JoinTopK, LogicalPlan}
 import org.apache.spark.sql.execution.UserProvidedPlanner
+import org.apache.spark.sql.execution.datasources.csv.{CsvToStruct, StructToCsv}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -1423,6 +1424,53 @@ object HivemallOps {
   }.as("rowid")
 
   /**
+   * Parses a column containing a CSV string into a [[StructType]] with the specified schema.
+   * Returns `null` in the case of an unparseable string.
+   * @group misc
+   *
+   * @param e a string column containing CSV data.
+   * @param schema the schema to use when parsing the CSV string
+   * @param options options to control how the CSV is parsed. Accepts the same options as
+   *                the CSV data source.
+   */
+  def from_csv(e: Column, schema: StructType, options: Map[String, String]): Column = withExpr {
+    CsvToStruct(schema, options, e.expr)
+  }
+
+  /**
+   * Parses a column containing a CSV string into a [[StructType]] with the specified schema.
+   * Returns `null` in the case of an unparseable string.
+   * @group misc
+   *
+   * @param e a string column containing CSV data.
+   * @param schema the schema to use when parsing the CSV string
+   */
+  def from_csv(e: Column, schema: StructType): Column =
+    from_csv(e, schema, Map.empty[String, String])
+
+  /**
+   * Converts a column containing a `StructType` into a CSV string with the specified schema.
+   * Throws an exception in the case of an unsupported type.
+   * @group misc
+   *
+   * @param e a struct column.
+   * @param options options to control how the struct column is converted into a CSV string.
+   *                Accepts the same options as the CSV data source.
+   */
+  def to_csv(e: Column, options: Map[String, String]): Column = withExpr {
+    StructToCsv(options, e.expr)
+  }
+
+  /**
+   * Converts a column containing a `StructType` into a CSV string with the specified schema.
+   * Throws an exception in the case of an unsupported type.
+   * @group misc
+   *
+   * @param e a struct column.
+   */
+  def to_csv(e: Column): Column = to_csv(e, Map.empty[String, String])
+
+  /**
    * A convenient function to wrap an expression and produce a Column.
    */
   @inline private def withExpr(expr: Expression): Column = Column(expr)


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/33baa240/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
index ed56bc3..d595df2 100644
--- a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
+++ b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
@@ -461,6 +461,25 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {
     }
   }
 
+  test("misc - from_csv") {
+    import hiveContext.implicits._
+    val df = Seq("""1,abc""").toDF()
+    val schema = new StructType().add("a", IntegerType).add("b", StringType)
+    checkAnswer(
+      df.select(from_csv($"value", schema)),
+      Row(Row(1, "abc")) :: Nil)
+  }
+
+  test("misc - to_csv") {
+    import hiveContext.implicits._
+    val df = Seq((1, "a", (0, 3.9, "abc")), (8, "c", (2, 0.4, "def"))).toDF()
+    checkAnswer(
+      df.select(to_csv($"_3")),
+      Row("0,3.9,abc") ::
+      Row("2,0.4,def") ::
+      Nil)
+  }
+
   /**
    * This test fails because;
    *
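
The tests above exercise the happy path only; verifySchema in StructToCsv admits just atomic field types. A sketch of the failure mode, under the same assumed spark-shell session as above (the exact error prefix is Spark's type-check wording, not guaranteed here):

  import org.apache.spark.sql.AnalysisException
  import org.apache.spark.sql.functions.struct
  import org.apache.spark.sql.hive.HivemallOps.to_csv
  import spark.implicits._

  // A struct with an array field fails checkInputDataTypes, so the eager analysis
  // in select() raises an AnalysisException whose message includes
  // "CSV data source does not support array<int> data type."
  val df = Seq((1, Seq(1, 2))).toDF("a", "b")
  try {
    df.select(to_csv(struct($"a", $"b")))
  } catch {
    case e: AnalysisException => println(e.getMessage)
  }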