Repository: spark
Updated Branches:
refs/heads/master 1a21be15f -> b8ff6888e
[SPARK-8992][SQL] Add pivot to dataframe api
This adds a pivot method to the dataframe api.
Following the lead of cube and rollup this adds a Pivot operator that is
translated into an Aggregate by the analyzer.
Currently the syntax is like:
~~courseSales.pivot(Seq($"year"), $"course", Seq("dotNET", "Java"),
sum($"earnings"))~~
~~Would we be interested in the following syntax also/alternatively? and~~
courseSales.groupBy($"year").pivot($"course", "dotNET",
"Java").agg(sum($"earnings"))
//or
courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings"))
Later we can add it to `SQLParser`, but as Hive doesn't support it we can't add
it there, right?
~~Also what would be the suggested Java friendly method signature for this?~~
Author: Andrew Ray <[email protected]>
Closes #7841 from aray/sql-pivot.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b8ff6888
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b8ff6888
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b8ff6888
Branch: refs/heads/master
Commit: b8ff6888e76b437287d7d6bf2d4b9c759710a195
Parents: 1a21be1
Author: Andrew Ray <[email protected]>
Authored: Wed Nov 11 16:23:24 2015 -0800
Committer: Yin Huai <[email protected]>
Committed: Wed Nov 11 16:23:24 2015 -0800
----------------------------------------------------------------------
.../spark/sql/catalyst/analysis/Analyzer.scala | 42 ++++++++
.../catalyst/plans/logical/basicOperators.scala | 14 +++
.../org/apache/spark/sql/GroupedData.scala | 103 +++++++++++++++++--
.../scala/org/apache/spark/sql/SQLConf.scala | 7 ++
.../apache/spark/sql/DataFramePivotSuite.scala | 87 ++++++++++++++++
.../org/apache/spark/sql/test/SQLTestData.scala | 12 +++
6 files changed, 255 insertions(+), 10 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/b8ff6888/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index a9cd9a7..2f4670b 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -72,6 +72,7 @@ class Analyzer(
ResolveRelations ::
ResolveReferences ::
ResolveGroupingAnalytics ::
+ ResolvePivot ::
ResolveSortReferences ::
ResolveGenerate ::
ResolveFunctions ::
@@ -166,6 +167,10 @@ class Analyzer(
case g: GroupingAnalytics if g.child.resolved &&
hasUnresolvedAlias(g.aggregations) =>
g.withNewAggs(assignAliases(g.aggregations))
+ case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child)
+ if child.resolved && hasUnresolvedAlias(groupByExprs) =>
+ Pivot(assignAliases(groupByExprs), pivotColumn, pivotValues,
aggregates, child)
+
case Project(projectList, child) if child.resolved &&
hasUnresolvedAlias(projectList) =>
Project(assignAliases(projectList), child)
}
@@ -248,6 +253,43 @@ class Analyzer(
}
}
+ object ResolvePivot extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case p: Pivot if !p.childrenResolved => p
+ case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) =>
+ val singleAgg = aggregates.size == 1
+ val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap {
value =>
+ def ifExpr(expr: Expression) = {
+ If(EqualTo(pivotColumn, value), expr, Literal(null))
+ }
+ aggregates.map { aggregate =>
+ val filteredAggregate = aggregate.transformDown {
+ // The assumption is that the aggregate function ignores nulls. This is
true for all current
+ // AggregateFunctions with the exception of First and Last in
their default mode
+ // (which we handle) and possibly some Hive UDAFs.
+ case First(expr, _) =>
+ First(ifExpr(expr), Literal(true))
+ case Last(expr, _) =>
+ Last(ifExpr(expr), Literal(true))
+ case a: AggregateFunction =>
+ a.withNewChildren(a.children.map(ifExpr))
+ }
+ if (filteredAggregate.fastEquals(aggregate)) {
+ throw new AnalysisException(
+ s"Aggregate expression required for pivot, found '$aggregate'")
+ }
+ val name = if (singleAgg) value.toString else value + "_" +
aggregate.prettyString
+ Alias(filteredAggregate, name)()
+ }
+ }
+ val newGroupByExprs = groupByExprs.map {
+ case UnresolvedAlias(e) => e
+ case e => e
+ }
+ Aggregate(newGroupByExprs, groupByExprs ++ pivotAggregates, child)
+ }
+ }
+
/**
* Replaces [[UnresolvedRelation]]s with concrete relations from the catalog.
*/
http://git-wip-us.apache.org/repos/asf/spark/blob/b8ff6888/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 597f03e..32b09b5 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -386,6 +386,20 @@ case class Rollup(
this.copy(aggregations = aggs)
}
+case class Pivot(
+ groupByExprs: Seq[NamedExpression],
+ pivotColumn: Expression,
+ pivotValues: Seq[Literal],
+ aggregates: Seq[Expression],
+ child: LogicalPlan) extends UnaryNode {
+ override def output: Seq[Attribute] = groupByExprs.map(_.toAttribute) ++
aggregates match {
+ case agg :: Nil => pivotValues.map(value =>
AttributeReference(value.toString, agg.dataType)())
+ case _ => pivotValues.flatMap{ value =>
+ aggregates.map(agg => AttributeReference(value + "_" + agg.prettyString,
agg.dataType)())
+ }
+ }
+}
+
case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output
http://git-wip-us.apache.org/repos/asf/spark/blob/b8ff6888/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index 5babf2c..63dd7fb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -24,8 +24,8 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction,
UnresolvedAlias, UnresolvedAttribute, Star}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate}
-import org.apache.spark.sql.types.NumericType
+import org.apache.spark.sql.catalyst.plans.logical.{Pivot, Rollup, Cube,
Aggregate}
+import org.apache.spark.sql.types.{StringType, NumericType}
/**
@@ -50,14 +50,8 @@ class GroupedData protected[sql](
aggExprs
}
- val aliasedAgg = aggregates.map {
- // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve
UnresolvedAttribute, we
- // will remove intermediate Alias for ExtractValue chain, and we need to
alias it again to
- // make it a NamedExpression.
- case u: UnresolvedAttribute => UnresolvedAlias(u)
- case expr: NamedExpression => expr
- case expr: Expression => Alias(expr, expr.prettyString)()
- }
+ val aliasedAgg = aggregates.map(alias)
+
groupType match {
case GroupedData.GroupByType =>
DataFrame(
@@ -68,9 +62,22 @@ class GroupedData protected[sql](
case GroupedData.CubeType =>
DataFrame(
df.sqlContext, Cube(groupingExprs, df.logicalPlan, aliasedAgg))
+ case GroupedData.PivotType(pivotCol, values) =>
+ val aliasedGrps = groupingExprs.map(alias)
+ DataFrame(
+ df.sqlContext, Pivot(aliasedGrps, pivotCol, values, aggExprs,
df.logicalPlan))
}
}
+ // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve
UnresolvedAttribute, we
+ // will remove intermediate Alias for ExtractValue chain, and we need to
alias it again to
+ // make it a NamedExpression.
+ private[this] def alias(expr: Expression): NamedExpression = expr match {
+ case u: UnresolvedAttribute => UnresolvedAlias(u)
+ case expr: NamedExpression => expr
+ case expr: Expression => Alias(expr, expr.prettyString)()
+ }
+
private[this] def aggregateNumericColumns(colNames: String*)(f: Expression
=> AggregateFunction)
: DataFrame = {
@@ -273,6 +280,77 @@ class GroupedData protected[sql](
def sum(colNames: String*): DataFrame = {
aggregateNumericColumns(colNames : _*)(Sum)
}
+
+ /**
+ * (Scala-specific) Pivots a column of the current [[DataFrame]] and
performs the specified
+ * aggregation.
+ * {{{
+ * // Compute the sum of earnings for each year by course with each
course as a separate column
+ * df.groupBy($"year").pivot($"course", "dotNET",
"Java").agg(sum($"earnings"))
+ * // Or without specifying column values
+ * df.groupBy($"year").pivot($"course").agg(sum($"earnings"))
+ * }}}
+ * @param pivotColumn Column to pivot
+ * @param values Optional list of values of pivotColumn that will be
translated to columns in the
+ * output data frame. If values are not provided the method
will do an immediate
+ * call to .distinct() on the pivot column.
+ * @since 1.6.0
+ */
+ @scala.annotation.varargs
+ def pivot(pivotColumn: Column, values: Column*): GroupedData = groupType
match {
+ case _: GroupedData.PivotType =>
+ throw new UnsupportedOperationException("repeated pivots are not
supported")
+ case GroupedData.GroupByType =>
+ val pivotValues = if (values.nonEmpty) {
+ values.map {
+ case Column(literal: Literal) => literal
+ case other =>
+ throw new UnsupportedOperationException(
+ s"The values of a pivot must be literals, found $other")
+ }
+ } else {
+ // This is to prevent unintended OOM errors when the number of
distinct values is large
+ val maxValues =
df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES)
+ // Get the distinct values of the column and sort them so it's
consistent
+ val values = df.select(pivotColumn)
+ .distinct()
+ .sort(pivotColumn)
+ .map(_.get(0))
+ .take(maxValues + 1)
+ .map(Literal(_)).toSeq
+ if (values.length > maxValues) {
+ throw new RuntimeException(
+ s"The pivot column $pivotColumn has more than $maxValues distinct
values, " +
+ "this could indicate an error. " +
+ "If this was intended, set \"" +
SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key + "\" " +
+ s"to at least the number of distinct values of the pivot
column.")
+ }
+ values
+ }
+ new GroupedData(df, groupingExprs,
GroupedData.PivotType(pivotColumn.expr, pivotValues))
+ case _ =>
+ throw new UnsupportedOperationException("pivot is only supported after a
groupBy")
+ }
+
+ /**
+ * Pivots a column of the current [[DataFrame]] and performs the specified
aggregation.
+ * {{{
+ * // Compute the sum of earnings for each year by course with each
course as a separate column
+ * df.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings")
+ * // Or without specifying column values
+ * df.groupBy("year").pivot("course").sum("earnings")
+ * }}}
+ * @param pivotColumn Column to pivot
+ * @param values Optional list of values of pivotColumn that will be
translated to columns in the
+ * output data frame. If values are not provided the method
will do an immediate
+ * call to .distinct() on the pivot column.
+ * @since 1.6.0
+ */
+ @scala.annotation.varargs
+ def pivot(pivotColumn: String, values: Any*): GroupedData = {
+ val resolvedPivotColumn = Column(df.resolve(pivotColumn))
+ pivot(resolvedPivotColumn, values.map(functions.lit): _*)
+ }
}
@@ -307,4 +385,9 @@ private[sql] object GroupedData {
* To indicate it's the ROLLUP
*/
private[sql] object RollupType extends GroupType
+
+ /**
+ * To indicate it's the PIVOT
+ */
+ private[sql] case class PivotType(pivotCol: Expression, values:
Seq[Literal]) extends GroupType
}
http://git-wip-us.apache.org/repos/asf/spark/blob/b8ff6888/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index e02b502..41d28d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -437,6 +437,13 @@ private[spark] object SQLConf {
defaultValue = Some(true),
isPublic = false)
+ val DATAFRAME_PIVOT_MAX_VALUES = intConf(
+ "spark.sql.pivotMaxValues",
+ defaultValue = Some(10000),
+ doc = "When doing a pivot without specifying values for the pivot column
this is the maximum " +
+ "number of (distinct) values that will be collected without error."
+ )
+
val RUN_SQL_ON_FILES = booleanConf("spark.sql.runSQLOnFiles",
defaultValue = Some(true),
isPublic = false,
http://git-wip-us.apache.org/repos/asf/spark/blob/b8ff6888/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
----------------------------------------------------------------------
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
new file mode 100644
index 0000000..0c23d14
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
+
+class DataFramePivotSuite extends QueryTest with SharedSQLContext{
+ import testImplicits._
+
+ test("pivot courses with literals") {
+ checkAnswer(
+ courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java"))
+ .agg(sum($"earnings")),
+ Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
+ )
+ }
+
+ test("pivot year with literals") {
+ checkAnswer(
+ courseSales.groupBy($"course").pivot($"year", lit(2012),
lit(2013)).agg(sum($"earnings")),
+ Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
+ )
+ }
+
+ test("pivot courses with literals and multiple aggregations") {
+ checkAnswer(
+ courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java"))
+ .agg(sum($"earnings"), avg($"earnings")),
+ Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) ::
+ Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil
+ )
+ }
+
+ test("pivot year with string values (cast)") {
+ checkAnswer(
+ courseSales.groupBy("course").pivot("year", "2012",
"2013").sum("earnings"),
+ Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
+ )
+ }
+
+ test("pivot year with int values") {
+ checkAnswer(
+ courseSales.groupBy("course").pivot("year", 2012, 2013).sum("earnings"),
+ Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
+ )
+ }
+
+ test("pivot courses with no values") {
+ // Note: Java comes before dotNET in sorted order
+ checkAnswer(
+ courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings")),
+ Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil
+ )
+ }
+
+ test("pivot year with no values") {
+ checkAnswer(
+ courseSales.groupBy($"course").pivot($"year").agg(sum($"earnings")),
+ Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
+ )
+ }
+
+ test("pivot max values inforced") {
+ sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, 1)
+ intercept[RuntimeException](
+ courseSales.groupBy($"year").pivot($"course")
+ )
+ sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES,
+ SQLConf.DATAFRAME_PIVOT_MAX_VALUES.defaultValue.get)
+ }
+}
http://git-wip-us.apache.org/repos/asf/spark/blob/b8ff6888/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
----------------------------------------------------------------------
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
index 520dea7..abad0d7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -242,6 +242,17 @@ private[sql] trait SQLTestData { self =>
df
}
+ protected lazy val courseSales: DataFrame = {
+ val df = sqlContext.sparkContext.parallelize(
+ CourseSales("dotNET", 2012, 10000) ::
+ CourseSales("Java", 2012, 20000) ::
+ CourseSales("dotNET", 2012, 5000) ::
+ CourseSales("dotNET", 2013, 48000) ::
+ CourseSales("Java", 2013, 30000) :: Nil).toDF()
+ df.registerTempTable("courseSales")
+ df
+ }
+
/**
* Initialize all test data such that all temp tables are properly
registered.
*/
@@ -295,4 +306,5 @@ private[sql] object SQLTestData {
case class Person(id: Int, name: String, age: Int)
case class Salary(personId: Int, salary: Double)
case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b:
Boolean)
+ case class CourseSales(course: String, year: Int, earnings: Double)
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]