Repository: incubator-hivemall Updated Branches: refs/heads/master 210b7765b -> f7fc3041f
Close #61: [HIVEMALL-88][SPARK] Support a function to flatten nested schemas Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/f7fc3041 Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/f7fc3041 Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/f7fc3041 Branch: refs/heads/master Commit: f7fc3041fba258a578bf0bf4bd78d5422718777c Parents: 210b776 Author: Takeshi Yamamuro <yamam...@apache.org> Authored: Thu Mar 9 17:00:42 2017 +0900 Committer: Takeshi Yamamuro <yamam...@apache.org> Committed: Thu Mar 9 17:00:42 2017 +0900 ---------------------------------------------------------------------- docs/gitbook/SUMMARY.md | 1 + docs/gitbook/spark/misc/functions.md | 47 +++++++++++++++ .../org/apache/spark/sql/hive/HivemallOps.scala | 60 ++++++++++++++++++++ .../spark/sql/hive/HivemallOpsSuite.scala | 26 +++++++++ 4 files changed, 134 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f7fc3041/docs/gitbook/SUMMARY.md ---------------------------------------------------------------------- diff --git a/docs/gitbook/SUMMARY.md b/docs/gitbook/SUMMARY.md index 6840cac..4c6ed1b 100644 --- a/docs/gitbook/SUMMARY.md +++ b/docs/gitbook/SUMMARY.md @@ -163,6 +163,7 @@ * [Generic features](spark/misc/misc.md) * [Top-k Join processing](spark/misc/topk_join.md) + * [Other utility functions](spark/misc/functions.md) ## Part X - External References http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f7fc3041/docs/gitbook/spark/misc/functions.md ---------------------------------------------------------------------- diff --git a/docs/gitbook/spark/misc/functions.md b/docs/gitbook/spark/misc/functions.md new file mode 100644 index 0000000..23763dd --- /dev/null +++ b/docs/gitbook/spark/misc/functions.md @@ -0,0 +1,47 @@ +<!-- + Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +--> + +`df.flatten()` flattens a nested schema of `df` into a flat one. + +# Usage + +```scala +scala> val df = Seq((0, (1, (3.0, "a")), (5, 0.9))).toDF() +scala> df.printSchema +root + |-- _1: integer (nullable = false) + |-- _2: struct (nullable = true) + | |-- _1: integer (nullable = false) + | |-- _2: struct (nullable = true) + | | |-- _1: double (nullable = false) + | | |-- _2: string (nullable = true) + |-- _3: struct (nullable = true) + | |-- _1: integer (nullable = false) + | |-- _2: double (nullable = false) + +scala> df.flatten(separator = "$").printSchema +root + |-- _1: integer (nullable = false) + |-- _2$_1: integer (nullable = true) + |-- _2$_2$_1: double (nullable = true) + |-- _2$_2$_2: string (nullable = true) + |-- _3$_1: integer (nullable = true) + |-- _3$_2: double (nullable = true) +``` + http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f7fc3041/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala index 6883ac1..d7fa202 100644 --- a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala +++ b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala @@ -805,6 +805,66 @@ final class HivemallOps(df: DataFrame) extends Logging { JoinTopK(kInt, df.logicalPlan, right.logicalPlan, Inner, Option(joinExprs.expr))(score.named) } + private def doFlatten(schema: StructType, separator: Char, prefixParts: Seq[String] = Seq.empty) + : Seq[Column] = { + schema.fields.flatMap { f => + val colNameParts = prefixParts :+ f.name + f.dataType match { + case st: StructType => + doFlatten(st, separator, colNameParts) + case _ => + col(colNameParts.mkString(".")).as(colNameParts.mkString(separator.toString)) :: Nil + } + } + } + + // Converts string representation of a character to actual character + @throws[IllegalArgumentException] + private def toChar(str: String): Char = { + if (str.length == 1) { + str.charAt(0) match { + case '$' | '_' | '.' => str.charAt(0) + case _ => throw new IllegalArgumentException( + "Must use '$', '_', or '.' for separator, but got " + str) + } + } else { + throw new IllegalArgumentException( + s"Separator cannot be more than one character: $str") + } + } + + /** + * Flattens a nested schema into a flat one. + * @group misc + * + * For example: + * {{{ + * scala> val df = Seq((0, (1, (3.0, "a")), (5, 0.9))).toDF() + * scala> df.printSchema + * root + * |-- _1: integer (nullable = false) + * |-- _2: struct (nullable = true) + * | |-- _1: integer (nullable = false) + * | |-- _2: struct (nullable = true) + * | | |-- _1: double (nullable = false) + * | | |-- _2: string (nullable = true) + * |-- _3: struct (nullable = true) + * | |-- _1: integer (nullable = false) + * | |-- _2: double (nullable = false) + * + * scala> df.flatten(separator = "$").printSchema + * root + * |-- _1: integer (nullable = false) + * |-- _2$_1: integer (nullable = true) + * |-- _2$_2$_1: double (nullable = true) + * |-- _2$_2$_2: string (nullable = true) + * |-- _3$_1: integer (nullable = true) + * |-- _3$_2: double (nullable = true) + * }}} + */ + def flatten(separator: String = "$"): DataFrame = + df.select(doFlatten(df.schema, toChar(separator)): _*) + /** * @see [[hivemall.dataset.LogisticRegressionDataGeneratorUDTF]] * @group misc http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f7fc3041/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala ---------------------------------------------------------------------- diff --git a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala index ed56bc3..74b2093 100644 --- a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala +++ b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala @@ -461,6 +461,32 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest { } } + test("misc - flatten") { + import hiveContext.implicits._ + val df = Seq((0, (1, "a", (3.0, "b")), (5, 0.9, "c", "d"), 9)).toDF() + assert(df.flatten().schema === StructType( + StructField("_1", IntegerType, nullable = false) :: + StructField("_2$_1", IntegerType, nullable = true) :: + StructField("_2$_2", StringType, nullable = true) :: + StructField("_2$_3$_1", DoubleType, nullable = true) :: + StructField("_2$_3$_2", StringType, nullable = true) :: + StructField("_3$_1", IntegerType, nullable = true) :: + StructField("_3$_2", DoubleType, nullable = true) :: + StructField("_3$_3", StringType, nullable = true) :: + StructField("_3$_4", StringType, nullable = true) :: + StructField("_4", IntegerType, nullable = false) :: + Nil + )) + checkAnswer(df.flatten("$").select("_2$_1"), Row(1)) + checkAnswer(df.flatten("_").select("_2__1"), Row(1)) + checkAnswer(df.flatten(".").select("`_2._1`"), Row(1)) + + val errMsg1 = intercept[IllegalArgumentException] { df.flatten("\t") } + assert(errMsg1.getMessage.startsWith("Must use '$', '_', or '.' for separator, but got")) + val errMsg2 = intercept[IllegalArgumentException] { df.flatten("12") } + assert(errMsg2.getMessage.startsWith("Separator cannot be more than one character:")) + } + /** * This test fails because; *