Repository: incubator-hivemall
Updated Branches:
  refs/heads/master 210b7765b -> f7fc3041f


Close #61: [HIVEMALL-88][SPARK] Support a function to flatten nested schemas


Project: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/commit/f7fc3041
Tree: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/tree/f7fc3041
Diff: http://git-wip-us.apache.org/repos/asf/incubator-hivemall/diff/f7fc3041

Branch: refs/heads/master
Commit: f7fc3041fba258a578bf0bf4bd78d5422718777c
Parents: 210b776
Author: Takeshi Yamamuro <yamam...@apache.org>
Authored: Thu Mar 9 17:00:42 2017 +0900
Committer: Takeshi Yamamuro <yamam...@apache.org>
Committed: Thu Mar 9 17:00:42 2017 +0900

----------------------------------------------------------------------
 docs/gitbook/SUMMARY.md                         |  1 +
 docs/gitbook/spark/misc/functions.md            | 47 +++++++++++++++
 .../org/apache/spark/sql/hive/HivemallOps.scala | 60 ++++++++++++++++++++
 .../spark/sql/hive/HivemallOpsSuite.scala       | 26 +++++++++
 4 files changed, 134 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f7fc3041/docs/gitbook/SUMMARY.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/SUMMARY.md b/docs/gitbook/SUMMARY.md
index 6840cac..4c6ed1b 100644
--- a/docs/gitbook/SUMMARY.md
+++ b/docs/gitbook/SUMMARY.md
@@ -163,6 +163,7 @@
 
 * [Generic features](spark/misc/misc.md)
     * [Top-k Join processing](spark/misc/topk_join.md)
+    * [Other utility functions](spark/misc/functions.md)
 
 ## Part X - External References
 

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f7fc3041/docs/gitbook/spark/misc/functions.md
----------------------------------------------------------------------
diff --git a/docs/gitbook/spark/misc/functions.md b/docs/gitbook/spark/misc/functions.md
new file mode 100644
index 0000000..23763dd
--- /dev/null
+++ b/docs/gitbook/spark/misc/functions.md
@@ -0,0 +1,47 @@
+<!--
+  Licensed to the Apache Software Foundation (ASF) under one
+  or more contributor license agreements.  See the NOTICE file
+  distributed with this work for additional information
+  regarding copyright ownership.  The ASF licenses this file
+  to you under the Apache License, Version 2.0 (the
+  "License"); you may not use this file except in compliance
+  with the License.  You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing,
+  software distributed under the License is distributed on an
+  "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+  KIND, either express or implied.  See the License for the
+  specific language governing permissions and limitations
+  under the License.
+-->
+
+`df.flatten()` flattens a nested schema of `df` into a flat one.
+
+# Usage
+
+```scala
+scala> val df = Seq((0, (1, (3.0, "a")), (5, 0.9))).toDF()
+scala> df.printSchema
+root
+ |-- _1: integer (nullable = false)
+ |-- _2: struct (nullable = true)
+ |    |-- _1: integer (nullable = false)
+ |    |-- _2: struct (nullable = true)
+ |    |    |-- _1: double (nullable = false)
+ |    |    |-- _2: string (nullable = true)
+ |-- _3: struct (nullable = true)
+ |    |-- _1: integer (nullable = false)
+ |    |-- _2: double (nullable = false)
+
+scala> df.flatten(separator = "$").printSchema
+root
+ |-- _1: integer (nullable = false)
+ |-- _2$_1: integer (nullable = true)
+ |-- _2$_2$_1: double (nullable = true)
+ |-- _2$_2$_2: string (nullable = true)
+ |-- _3$_1: integer (nullable = true)
+ |-- _3$_2: double (nullable = true)
+```
+

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f7fc3041/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
index 6883ac1..d7fa202 100644
--- a/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
+++ b/spark/spark-2.1/src/main/scala/org/apache/spark/sql/hive/HivemallOps.scala
@@ -805,6 +805,66 @@ final class HivemallOps(df: DataFrame) extends Logging {
    JoinTopK(kInt, df.logicalPlan, right.logicalPlan, Inner, Option(joinExprs.expr))(score.named)
   }
 
+  private def doFlatten(schema: StructType, separator: Char, prefixParts: 
Seq[String] = Seq.empty)
+    : Seq[Column] = {
+    schema.fields.flatMap { f =>
+      val colNameParts = prefixParts :+ f.name
+      f.dataType match {
+        case st: StructType =>
+          doFlatten(st, separator, colNameParts)
+        case _ =>
+          
col(colNameParts.mkString(".")).as(colNameParts.mkString(separator.toString)) 
:: Nil
+      }
+    }
+  }
+
+  // Converts string representation of a character to actual character
+  @throws[IllegalArgumentException]
+  private def toChar(str: String): Char = {
+    if (str.length == 1) {
+      str.charAt(0) match {
+        case '$' | '_' | '.' => str.charAt(0)
+        case _ => throw new IllegalArgumentException(
+          "Must use '$', '_', or '.' for separator, but got " + str)
+      }
+    } else {
+      throw new IllegalArgumentException(
+        s"Separator cannot be more than one character: $str")
+    }
+  }
+
+  /**
+   * Flattens a nested schema into a flat one.
+   * @group misc
+   *
+   * For example:
+   * {{{
+   *  scala> val df = Seq((0, (1, (3.0, "a")), (5, 0.9))).toDF()
+   *  scala> df.printSchema
+   *  root
+   *   |-- _1: integer (nullable = false)
+   *   |-- _2: struct (nullable = true)
+   *   |    |-- _1: integer (nullable = false)
+   *   |    |-- _2: struct (nullable = true)
+   *   |    |    |-- _1: double (nullable = false)
+   *   |    |    |-- _2: string (nullable = true)
+   *   |-- _3: struct (nullable = true)
+   *   |    |-- _1: integer (nullable = false)
+   *   |    |-- _2: double (nullable = false)
+   *
+   *  scala> df.flatten(separator = "$").printSchema
+   *  root
+   *   |-- _1: integer (nullable = false)
+   *   |-- _2$_1: integer (nullable = true)
+   *   |-- _2$_2$_1: double (nullable = true)
+   *   |-- _2$_2$_2: string (nullable = true)
+   *   |-- _3$_1: integer (nullable = true)
+   *   |-- _3$_2: double (nullable = true)
+   * }}}
+   */
+  def flatten(separator: String = "$"): DataFrame =
+    df.select(doFlatten(df.schema, toChar(separator)): _*)
+
   /**
    * @see [[hivemall.dataset.LogisticRegressionDataGeneratorUDTF]]
    * @group misc

http://git-wip-us.apache.org/repos/asf/incubator-hivemall/blob/f7fc3041/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
----------------------------------------------------------------------
diff --git a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
index ed56bc3..74b2093 100644
--- a/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
+++ b/spark/spark-2.1/src/test/scala/org/apache/spark/sql/hive/HivemallOpsSuite.scala
@@ -461,6 +461,32 @@ final class HivemallOpsWithFeatureSuite extends HivemallFeatureQueryTest {
     }
   }
 
+  test("misc - flatten") {
+    import hiveContext.implicits._
+    val df = Seq((0, (1, "a", (3.0, "b")), (5, 0.9, "c", "d"), 9)).toDF()
+    assert(df.flatten().schema === StructType(
+      StructField("_1", IntegerType, nullable = false) ::
+      StructField("_2$_1", IntegerType, nullable = true) ::
+      StructField("_2$_2", StringType, nullable = true) ::
+      StructField("_2$_3$_1", DoubleType, nullable = true) ::
+      StructField("_2$_3$_2", StringType, nullable = true) ::
+      StructField("_3$_1", IntegerType, nullable = true) ::
+      StructField("_3$_2", DoubleType, nullable = true) ::
+      StructField("_3$_3", StringType, nullable = true) ::
+      StructField("_3$_4", StringType, nullable = true) ::
+      StructField("_4", IntegerType, nullable = false) ::
+      Nil
+    ))
+    checkAnswer(df.flatten("$").select("_2$_1"), Row(1))
+    checkAnswer(df.flatten("_").select("_2__1"), Row(1))
+    checkAnswer(df.flatten(".").select("`_2._1`"), Row(1))
+
+    val errMsg1 = intercept[IllegalArgumentException] { df.flatten("\t") }
+    assert(errMsg1.getMessage.startsWith("Must use '$', '_', or '.' for 
separator, but got"))
+    val errMsg2 = intercept[IllegalArgumentException] { df.flatten("12") }
+    assert(errMsg2.getMessage.startsWith("Separator cannot be more than one 
character:"))
+  }
+
   /**
    * This test fails because;
    *

Reply via email to