Repository: spark
Updated Branches:
refs/heads/branch-1.4 a238c23b0 -> 778a0548c
[SPARK-7548] [SQL] Add explode function for DataFrames
Add an `explode` function for DataFrames and modify the analyzer so that a single
table generating function (TGF) can be present in a select clause along with other
expressions. There are currently the following restrictions:
- only top level TGFs are allowed (i.e. no `select(explode('list) + 1)`)
- only one may be present in a single select to avoid potentially confusing
implicit Cartesian products.
TODO:
- [ ] Python
Author: Michael Armbrust <[email protected]>
Closes #6107 from marmbrus/explodeFunction and squashes the following commits:
7ee2c87 [Michael Armbrust] whitespace
6f80ba3 [Michael Armbrust] Update dataframe.py
c176c89 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into
explodeFunction
81b5da3 [Michael Armbrust] style
d3faa05 [Michael Armbrust] fix self join case
f9e1e3e [Michael Armbrust] fix python, add since
4f0d0a9 [Michael Armbrust] Merge remote-tracking branch 'origin/master' into
explodeFunction
e710fe4 [Michael Armbrust] add java and python
52ca0dc [Michael Armbrust] [SPARK-7548][SQL] Add explode function for
dataframes.
(cherry picked from commit 6d0633e3ec9518278fcc7eba58549d4ad3d5813f)
Signed-off-by: Michael Armbrust <[email protected]>
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/778a0548
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/778a0548
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/778a0548
Branch: refs/heads/branch-1.4
Commit: 778a0548cca35496f6546c3710270201283b749d
Parents: a238c23
Author: Michael Armbrust <[email protected]>
Authored: Thu May 14 19:49:44 2015 -0700
Committer: Michael Armbrust <[email protected]>
Committed: Thu May 14 19:51:00 2015 -0700
----------------------------------------------------------------------
python/pyspark/sql/dataframe.py | 12 +-
python/pyspark/sql/functions.py | 20 ++++
python/pyspark/sql/tests.py | 15 +++
.../spark/sql/catalyst/analysis/Analyzer.scala | 117 ++++++++++++-------
.../catalyst/plans/logical/basicOperators.scala | 3 +
.../sql/catalyst/analysis/AnalysisSuite.scala | 10 +-
.../scala/org/apache/spark/sql/Column.scala | 27 ++++-
.../scala/org/apache/spark/sql/DataFrame.scala | 5 +-
.../scala/org/apache/spark/sql/functions.scala | 5 +
.../spark/sql/ColumnExpressionSuite.scala | 60 ++++++++++
10 files changed, 223 insertions(+), 51 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/778a0548/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 82cb1c2..2ed95ac 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1511,13 +1511,19 @@ class Column(object):
isNull = _unary_op("isNull", "True if the current expression is null.")
isNotNull = _unary_op("isNotNull", "True if the current expression is not
null.")
- def alias(self, alias):
- """Return a alias for this column
+ def alias(self, *alias):
+ """Returns this column aliased with a new name or names (in the case
of expressions that
+ return more than one column, such as explode).
>>> df.select(df.age.alias("age2")).collect()
[Row(age2=2), Row(age2=5)]
"""
- return Column(getattr(self._jc, "as")(alias))
+
+ if len(alias) == 1:
+ return Column(getattr(self._jc, "as")(alias[0]))
+ else:
+ sc = SparkContext._active_spark_context
+ return Column(getattr(self._jc, "as")(_to_seq(sc, list(alias))))
@ignore_unicode_prefix
def cast(self, dataType):
http://git-wip-us.apache.org/repos/asf/spark/blob/778a0548/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index d91265e..6cd6974 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -169,6 +169,26 @@ def approxCountDistinct(col, rsd=None):
return Column(jc)
+def explode(col):
+ """Returns a new row for each element in the given array or map.
+
+ >>> from pyspark.sql import Row
+ >>> eDF = sqlContext.createDataFrame([Row(a=1, intlist=[1,2,3],
mapfield={"a": "b"})])
+ >>> eDF.select(explode(eDF.intlist).alias("anInt")).collect()
+ [Row(anInt=1), Row(anInt=2), Row(anInt=3)]
+
+ >>> eDF.select(explode(eDF.mapfield).alias("key", "value")).show()
+ +---+-----+
+ |key|value|
+ +---+-----+
+ | a| b|
+ +---+-----+
+ """
+ sc = SparkContext._active_spark_context
+ jc = sc._jvm.functions.explode(_to_java_column(col))
+ return Column(jc)
+
+
def coalesce(*cols):
"""Returns the first column that is not null.
http://git-wip-us.apache.org/repos/asf/spark/blob/778a0548/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 1922d03..d37c5db 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -117,6 +117,21 @@ class SQLTests(ReusedPySparkTestCase):
ReusedPySparkTestCase.tearDownClass()
shutil.rmtree(cls.tempdir.name, ignore_errors=True)
+ def test_explode(self):
+ from pyspark.sql.functions import explode
+ d = [Row(a=1, intlist=[1, 2, 3], mapfield={"a": "b"})]
+ rdd = self.sc.parallelize(d)
+ data = self.sqlCtx.createDataFrame(rdd)
+
+ result =
data.select(explode(data.intlist).alias("a")).select("a").collect()
+ self.assertEqual(result[0][0], 1)
+ self.assertEqual(result[1][0], 2)
+ self.assertEqual(result[2][0], 3)
+
+ result = data.select(explode(data.mapfield).alias("a",
"b")).select("a", "b").collect()
+ self.assertEqual(result[0][0], "a")
+ self.assertEqual(result[0][1], "b")
+
def test_udf_with_callable(self):
d = [Row(number=i, squared=i**2) for i in range(10)]
rdd = self.sc.parallelize(d)
http://git-wip-us.apache.org/repos/asf/spark/blob/778a0548/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 4baeeb5..0b6e1d4 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -73,7 +73,6 @@ class Analyzer(
ResolveGroupingAnalytics ::
ResolveSortReferences ::
ResolveGenerate ::
- ImplicitGenerate ::
ResolveFunctions ::
ExtractWindowExpressions ::
GlobalAggregates ::
@@ -323,6 +322,11 @@ class Analyzer(
if
findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
(oldVersion, oldVersion.copy(aggregateExpressions =
newAliases(aggregateExpressions)))
+ case oldVersion: Generate
+ if
oldVersion.generatedSet.intersect(conflictingAttributes).nonEmpty =>
+ val newOutput = oldVersion.generatorOutput.map(_.newInstance())
+ (oldVersion, oldVersion.copy(generatorOutput = newOutput))
+
case oldVersion @ Window(_, windowExpressions, _, child)
if
AttributeSet(windowExpressions.map(_.toAttribute)).intersect(conflictingAttributes)
.nonEmpty =>
@@ -521,66 +525,89 @@ class Analyzer(
}
/**
- * When a SELECT clause has only a single expression and that expression is a
- * [[catalyst.expressions.Generator Generator]] we convert the
- * [[catalyst.plans.logical.Project Project]] to a
[[catalyst.plans.logical.Generate Generate]].
+ * Rewrites table generating expressions that either need one or more of the
following in order
+ * to be resolved:
+ * - concrete attribute references for their output.
+ * - to be relocated from a SELECT clause (i.e. from a [[Project]]) into a
[[Generate]]).
+ *
+ * Names for the output [[Attribute]]s are extracted from [[Alias]] or
[[MultiAlias]] expressions
+ * that wrap the [[Generator]]. If more than one [[Generator]] is found in a
Project, an
+ * [[AnalysisException]] is thrown.
*/
- object ImplicitGenerate extends Rule[LogicalPlan] {
+ object ResolveGenerate extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case Project(Seq(Alias(g: Generator, name)), child) =>
- Generate(g, join = false, outer = false,
- qualifier = None, UnresolvedAttribute(name) :: Nil, child)
- case Project(Seq(MultiAlias(g: Generator, names)), child) =>
- Generate(g, join = false, outer = false,
- qualifier = None, names.map(UnresolvedAttribute(_)), child)
+ case p: Generate if !p.child.resolved || !p.generator.resolved => p
+ case g: Generate if g.resolved == false =>
+ g.copy(
+ generatorOutput = makeGeneratorOutput(g.generator,
g.generatorOutput.map(_.name)))
+
+ case p @ Project(projectList, child) =>
+ // Holds the resolved generator, if one exists in the project list.
+ var resolvedGenerator: Generate = null
+
+ val newProjectList = projectList.flatMap {
+ case AliasedGenerator(generator, names) if
generator.childrenResolved =>
+ if (resolvedGenerator != null) {
+ failAnalysis(
+ s"Only one generator allowed per select but
${resolvedGenerator.nodeName} and " +
+ s"and ${generator.nodeName} found.")
+ }
+
+ resolvedGenerator =
+ Generate(
+ generator,
+ join = projectList.size > 1, // Only join if there are other
expressions in SELECT.
+ outer = false,
+ qualifier = None,
+ generatorOutput = makeGeneratorOutput(generator, names),
+ child)
+
+ resolvedGenerator.generatorOutput
+ case other => other :: Nil
+ }
+
+ if (resolvedGenerator != null) {
+ Project(newProjectList, resolvedGenerator)
+ } else {
+ p
+ }
}
- }
- /**
- * Resolve the Generate, if the output names specified, we will take them,
otherwise
- * we will try to provide the default names, which follow the same rule with
Hive.
- */
- object ResolveGenerate extends Rule[LogicalPlan] {
- // Construct the output attributes for the generator,
- // The output attribute names can be either specified or
- // auto generated.
+ /** Extracts a [[Generator]] expression and any names assigned by aliases
to their output. */
+ private object AliasedGenerator {
+ def unapply(e: Expression): Option[(Generator, Seq[String])] = e match {
+ case Alias(g: Generator, name) => Some((g, name :: Nil))
+ case MultiAlias(g: Generator, names) => Some(g, names)
+ case _ => None
+ }
+ }
+
+ /**
+ * Construct the output attributes for a [[Generator]], given a list of
names. If the list of
+ * names is empty, names are assigned by ordinal (i.e., _c0, _c1, ...) to
match Hive's defaults.
+ */
private def makeGeneratorOutput(
generator: Generator,
- generatorOutput: Seq[Attribute]): Seq[Attribute] = {
+ names: Seq[String]): Seq[Attribute] = {
val elementTypes = generator.elementTypes
- if (generatorOutput.length == elementTypes.length) {
- generatorOutput.zip(elementTypes).map {
- case (a, (t, nullable)) if !a.resolved =>
- AttributeReference(a.name, t, nullable)()
- case (a, _) => a
+ if (names.length == elementTypes.length) {
+ names.zip(elementTypes).map {
+ case (name, (t, nullable)) =>
+ AttributeReference(name, t, nullable)()
}
- } else if (generatorOutput.length == 0) {
+ } else if (names.isEmpty) {
elementTypes.zipWithIndex.map {
// keep the default column names as Hive does _c0, _c1, _cN
case ((t, nullable), i) => AttributeReference(s"_c$i", t, nullable)()
}
} else {
- throw new AnalysisException(
- s"""
- |The number of aliases supplied in the AS clause does not match
- |the number of columns output by the UDTF expected
- |${elementTypes.size} aliases but got ${generatorOutput.size}
- """.stripMargin)
+ failAnalysis(
+ "The number of aliases supplied in the AS clause does not match the
number of columns " +
+ s"output by the UDTF expected ${elementTypes.size} aliases but got "
+
+ s"${names.mkString(",")} ")
}
}
-
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case p: Generate if !p.child.resolved || !p.generator.resolved => p
- case p: Generate if p.resolved == false =>
- // if the generator output names are not specified, we will use the
default ones.
- Generate(
- p.generator,
- join = p.join,
- outer = p.outer,
- p.qualifier,
- makeGeneratorOutput(p.generator, p.generatorOutput), p.child)
- }
}
/**
http://git-wip-us.apache.org/repos/asf/spark/blob/778a0548/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 0f349f9..01f4b6e 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -59,6 +59,9 @@ case class Generate(
child: LogicalPlan)
extends UnaryNode {
+ /** The set of attributes produced by the generator (i.e. `generatorOutput`). */
+ def generatedSet: AttributeSet = AttributeSet(generatorOutput)
+
override lazy val resolved: Boolean = {
generator.resolved &&
childrenResolved &&
http://git-wip-us.apache.org/repos/asf/spark/blob/778a0548/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 6f2f355..e1d6ac4 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -72,6 +72,9 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
StructField("cField", StringType) :: Nil
))())
+ val listRelation = LocalRelation(
+ AttributeReference("list", ArrayType(IntegerType))())
+
before {
caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
@@ -159,11 +162,16 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
}
}
- errorMessages.foreach(m => assert(error.getMessage contains m))
+ errorMessages.foreach(m => assert(error.getMessage.toLowerCase contains
m.toLowerCase))
}
}
errorTest(
+ "too many generators",
+ listRelation.select(Explode('list).as('a), Explode('list).as('b)),
+ "only one generator" :: "explode" :: Nil)
+
+ errorTest(
"unresolved attributes",
testRelation.select('abcd),
"cannot resolve" :: "abcd" :: Nil)
http://git-wip-us.apache.org/repos/asf/spark/blob/778a0548/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 8bf1320..dc0aeea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -18,12 +18,13 @@
package org.apache.spark.sql
import scala.language.implicitConversions
+import scala.collection.JavaConversions._
import org.apache.spark.annotation.Experimental
import org.apache.spark.Logging
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute,
UnresolvedStar, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.analysis.{MultiAlias,
UnresolvedAttribute, UnresolvedStar, UnresolvedExtractValue}
import org.apache.spark.sql.types._
@@ -728,6 +729,30 @@ class Column(protected[sql] val expr: Expression) extends
Logging {
def as(alias: String): Column = Alias(expr, alias)()
/**
+ * (Scala-specific) Assigns the given aliases to the results of a table
generating function.
+ * {{{
+ * // Names the key and value columns produced by exploding a map.
+ * df.select(explode($"myMap").as("key" :: "value" :: Nil))
+ * }}}
+ *
+ * @group expr_ops
+ * @since 1.4.0
+ */
+ def as(aliases: Seq[String]): Column = MultiAlias(expr, aliases)
+
+ /**
+ * Assigns the given aliases to the results of a table generating function.
+ * {{{
+ * // Names the key and value columns produced by exploding a map.
+ * df.select(explode($"myMap").as(Array("key", "value")))
+ * }}}
+ *
+ * @group expr_ops
+ * @since 1.4.0
+ */
+ def as(aliases: Array[String]): Column = MultiAlias(expr, aliases)
+
+ /**
* Gives the column an alias.
* {{{
* // Renames colA to colB in select output.
http://git-wip-us.apache.org/repos/asf/spark/blob/778a0548/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 4fd5105..2e20c3d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -34,7 +34,7 @@ import org.apache.spark.annotation.{DeveloperApi,
Experimental}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.analysis.{ResolvedStar,
UnresolvedAttribute, UnresolvedRelation}
+import org.apache.spark.sql.catalyst.analysis.{MultiAlias, ResolvedStar,
UnresolvedAttribute, UnresolvedRelation}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{Filter, _}
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType}
@@ -593,6 +593,9 @@ class DataFrame private[sql](
def select(cols: Column*): DataFrame = {
val namedExpressions = cols.map {
case Column(expr: NamedExpression) => expr
+ // Leave an unaliased explode with an empty list of names since the
analyzer will generate the
+ // correct defaults after the nested expression's type has been resolved.
+ case Column(explode: Explode) => MultiAlias(explode, Nil)
case Column(expr: Expression) => Alias(expr, expr.prettyString)()
}
// When user continuously call `select`, speed up analysis by collapsing
`Project`
http://git-wip-us.apache.org/repos/asf/spark/blob/778a0548/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 4404ad8..6640631 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -364,6 +364,11 @@ object functions {
def coalesce(e: Column*): Column = Coalesce(e.map(_.expr))
/**
+ * Creates a new row for each element in the given array or map column.
+ */
+ def explode(e: Column): Column = Explode(e.expr)
+
+ /**
* Converts a string exprsesion to lower case.
*
* @group normal_funcs
http://git-wip-us.apache.org/repos/asf/spark/blob/778a0548/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
----------------------------------------------------------------------
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 269e185..9bdf201 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -27,6 +27,66 @@ import org.apache.spark.sql.types._
class ColumnExpressionSuite extends QueryTest {
import org.apache.spark.sql.TestData._
+ test("single explode") {
+ val df = Seq((1, Seq(1,2,3))).toDF("a", "intList")
+ checkAnswer(
+ df.select(explode('intList)),
+ Row(1) :: Row(2) :: Row(3) :: Nil)
+ }
+
+ test("explode and other columns") {
+ val df = Seq((1, Seq(1,2,3))).toDF("a", "intList")
+
+ checkAnswer(
+ df.select($"a", explode('intList)),
+ Row(1, 1) ::
+ Row(1, 2) ::
+ Row(1, 3) :: Nil)
+
+ checkAnswer(
+ df.select($"*", explode('intList)),
+ Row(1, Seq(1,2,3), 1) ::
+ Row(1, Seq(1,2,3), 2) ::
+ Row(1, Seq(1,2,3), 3) :: Nil)
+ }
+
+ test("aliased explode") {
+ val df = Seq((1, Seq(1,2,3))).toDF("a", "intList")
+
+ checkAnswer(
+ df.select(explode('intList).as('int)).select('int),
+ Row(1) :: Row(2) :: Row(3) :: Nil)
+
+ checkAnswer(
+ df.select(explode('intList).as('int)).select(sum('int)),
+ Row(6) :: Nil)
+ }
+
+ test("explode on map") {
+ val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
+
+ checkAnswer(
+ df.select(explode('map)),
+ Row("a", "b"))
+ }
+
+ test("explode on map with aliases") {
+ val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
+
+ checkAnswer(
+ df.select(explode('map).as("key1" :: "value1" :: Nil)).select("key1",
"value1"),
+ Row("a", "b"))
+ }
+
+ test("self join explode") {
+ val df = Seq((1, Seq(1,2,3))).toDF("a", "intList")
+ val exploded = df.select(explode('intList).as('i))
+
+ checkAnswer(
+ exploded.join(exploded, exploded("i") === exploded("i")).agg(count("*")),
+ Row(3) :: Nil)
+ }
+
test("collect on column produced by a binary operator") {
val df = Seq((1, 2, 3)).toDF("a", "b", "c")
checkAnswer(df.select(df("a") + df("b")), Seq(Row(3)))
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]