This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 55c3347c48f [SPARK-38864][SQL] Add unpivot / melt to Dataset
55c3347c48f is described below
commit 55c3347c48f93a9c5c5c2fb00b30f838eb081b7f
Author: Enrico Minack <[email protected]>
AuthorDate: Tue Jul 26 15:50:03 2022 +0800
[SPARK-38864][SQL] Add unpivot / melt to Dataset
### What changes were proposed in this pull request?
This proposes a Scala implementation of the `melt` (a.k.a. `unpivot`)
operation.
### Why are the changes needed?
The Scala Dataset API provides the `pivot` operation, but not its reverse
operation `unpivot` or `melt`. The `melt` operation is part of the [Pandas
API](https://pandas.pydata.org/docs/reference/api/pandas.melt.html), which is
why the PySpark Pandas API provides this method, implemented purely in Python.
[It should be implemented in
Scala](https://github.com/apache/spark/pull/26912#pullrequestreview-332975715)
to make this operation available to Scala / Java, SQL, PySpark, and to reuse
the implementation in PySpark Pandas APIs.
The `melt` / `unpivot` operation exists in other systems like
[BigQuery](https://cloud.google.com/bigquery/docs/reference/standard-sql/query-syntax#unpivot_operator),
[T-SQL](https://docs.microsoft.com/en-us/sql/t-sql/queries/from-using-pivot-and-unpivot?view=sql-server-ver15#unpivot-example),
[Oracle](https://www.oracletutorial.com/oracle-basics/oracle-unpivot/).
It supports expressions for id and value columns, including `*` expansion
and structs, so this change also fixes / includes SPARK-39292.
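For illustration, a minimal sketch of the new API (taken from the Scaladoc
example added in this patch; the example data is arbitrary):

    // assuming spark.implicits._ is in scope for toDF and $
    val df = Seq((1, 11, 12L), (2, 21, 22L)).toDF("id", "int", "long")

    // "id" stays as the identifier column; "int" and "long" are unpivoted
    // into (variable, value) rows, with IntegerType and LongType widened
    // to LongType for the value column
    df.unpivot(Array($"id"), Array($"int", $"long"), "variable", "value").show()
    // +---+--------+-----+
    // | id|variable|value|
    // +---+--------+-----+
    // |  1|     int|   11|
    // |  1|    long|   12|
    // |  2|     int|   21|
    // |  2|    long|   22|
    // +---+--------+-----+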
### Does this PR introduce _any_ user-facing change?
It adds `melt` to the `Dataset` API (Scala and Java).
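For example (a sketch based on the overloads added below, with `df` as above):

    // melt is an alias for unpivot
    df.melt(Array($"id"), Array($"int", $"long"), "variable", "value")

    // when no value columns are given, all non-id columns are unpivoted
    df.unpivot(Array($"id"), "variable", "value")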
### How was this patch tested?
It is tested in `DatasetUnpivotSuite` and `JavaDatasetSuite`.
Closes #36150 from EnricoMi/branch-melt.
Authored-by: Enrico Minack <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
core/src/main/resources/error/error-classes.json | 12 +
.../spark/sql/catalyst/analysis/Analyzer.scala | 41 ++
.../sql/catalyst/analysis/AnsiTypeCoercion.scala | 1 +
.../sql/catalyst/analysis/CheckAnalysis.scala | 8 +
.../spark/sql/catalyst/analysis/TypeCoercion.scala | 16 +
.../plans/logical/basicLogicalOperators.scala | 39 ++
.../sql/catalyst/rules/RuleIdCollection.scala | 1 +
.../spark/sql/catalyst/trees/TreePatterns.scala | 1 +
.../spark/sql/errors/QueryCompilationErrors.scala | 18 +
.../main/scala/org/apache/spark/sql/Dataset.scala | 138 +++++-
.../spark/sql/RelationalGroupedDataset.scala | 18 +
.../org/apache/spark/sql/DatasetUnpivotSuite.scala | 543 +++++++++++++++++++++
.../spark/sql/errors/QueryErrorsSuiteBase.scala | 3 +-
13 files changed, 837 insertions(+), 2 deletions(-)
diff --git a/core/src/main/resources/error/error-classes.json b/core/src/main/resources/error/error-classes.json
index e2a99c1a62e..29ca280719e 100644
--- a/core/src/main/resources/error/error-classes.json
+++ b/core/src/main/resources/error/error-classes.json
@@ -375,6 +375,18 @@
"Unable to acquire <requestedBytes> bytes of memory, got <receivedBytes>"
]
},
+ "UNPIVOT_REQUIRES_VALUE_COLUMNS" : {
+ "message" : [
+ "At least one value column needs to be specified for UNPIVOT, all
columns specified as ids"
+ ],
+ "sqlState" : "42000"
+ },
+ "UNPIVOT_VALUE_DATA_TYPE_MISMATCH" : {
+ "message" : [
+ "Unpivot value columns must share a least common type, some types do
not: [<types>]"
+ ],
+ "sqlState" : "42000"
+ },
"UNRECOGNIZED_SQL_TYPE" : {
"message" : [
"Unrecognized SQL type <typeName>"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index f40c260eb6f..a6108c2a3d3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -293,6 +293,7 @@ class Analyzer(override val catalogManager: CatalogManager)
ResolveUpCast ::
ResolveGroupingAnalytics ::
ResolvePivot ::
+ ResolveUnpivot ::
ResolveOrdinalInOrderByAndGroupBy ::
ResolveAggAliasInGroupBy ::
ResolveMissingReferences ::
@@ -514,6 +515,10 @@ class Analyzer(override val catalogManager: CatalogManager)
if child.resolved && groupByOpt.isDefined &&
hasUnresolvedAlias(groupByOpt.get) =>
Pivot(Some(assignAliases(groupByOpt.get)), pivotColumn, pivotValues,
aggregates, child)
+ case up: Unpivot if up.child.resolved &&
+ (hasUnresolvedAlias(up.ids) || hasUnresolvedAlias(up.values)) =>
+ up.copy(ids = assignAliases(up.ids), values = assignAliases(up.values))
+
case Project(projectList, child) if child.resolved && hasUnresolvedAlias(projectList) =>
Project(assignAliases(projectList), child)
@@ -859,6 +864,36 @@ class Analyzer(override val catalogManager: CatalogManager)
}
}
+ object ResolveUnpivot extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
+ _.containsPattern(UNPIVOT), ruleId) {
+
+ // once children and ids are resolved, we can determine values, if none were given
+ case up: Unpivot if up.childrenResolved && up.ids.forall(_.resolved) && up.values.isEmpty =>
+ up.copy(values = up.child.output.diff(up.ids))
+
+ case up: Unpivot if !up.childrenResolved || !up.ids.forall(_.resolved) ||
+ up.values.isEmpty || !up.values.forall(_.resolved) || !up.valuesTypeCoercioned => up
+
+ // TypeCoercionBase.UnpivotCoercion determines valueType
+ // and casts values once values are set and resolved
+ case Unpivot(ids, values, variableColumnName, valueColumnName, child) =>
+ // construct unpivot expressions for Expand
+ val exprs: Seq[Seq[Expression]] = values.map {
+ value => ids ++ Seq(Literal(value.name), value)
+ }
+
+ // construct output attributes
+ val output = ids.map(_.toAttribute) ++ Seq(
+ AttributeReference(variableColumnName, StringType, nullable = false)(),
+ AttributeReference(valueColumnName, values.head.dataType, values.exists(_.nullable))()
+ )
+
+ // expand the unpivot expressions
+ Expand(exprs, output, child)
+ }
+ }
+
private def isResolvingView: Boolean =
AnalysisContext.get.catalogAndNamespace.nonEmpty
private def isReferredTempViewName(nameParts: Seq[String]): Boolean = {
AnalysisContext.get.referredTempViewNames.exists { n =>
@@ -1349,6 +1384,12 @@ class Analyzer(override val catalogManager: CatalogManager)
case g: Generate if containsStar(g.generator.children) =>
throw QueryCompilationErrors.invalidStarUsageError("explode/json_tuple/UDTF",
extractStar(g.generator.children))
+ // If the Unpivot ids or values contain Stars, expand them.
+ case up: Unpivot if containsStar(up.ids) || containsStar(up.values) =>
+ up.copy(
+ ids = buildExpandedProjectList(up.ids, up.child),
+ values = buildExpandedProjectList(up.values, up.child)
+ )
case u @ Union(children, _, _)
// if there are duplicate output columns, give them unique expr ids
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
index fd3885fe834..56dbb2a8590 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
@@ -74,6 +74,7 @@ import org.apache.spark.sql.types._
*/
object AnsiTypeCoercion extends TypeCoercionBase {
override def typeCoercionRules: List[Rule[LogicalPlan]] =
+ UnpivotCoercion ::
WidenSetOperationTypes ::
new AnsiCombinedTypeCoercionRule(
InConversion ::
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index cf734b7aa26..3f5b535b947 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -422,6 +422,14 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
}
metrics.foreach(m => checkMetric(m, m))
+ // see Analyzer.ResolveUnpivot
+ case up: Unpivot
+ if up.childrenResolved && up.ids.forall(_.resolved) && up.values.isEmpty =>
+ throw QueryCompilationErrors.unpivotRequiresValueColumns()
+ // see TypeCoercionBase.UnpivotCoercion
+ case up: Unpivot if !up.valuesTypeCoercioned =>
+ throw QueryCompilationErrors.unpivotValDataTypeMismatchError(up.values)
+
case Sort(orders, _, _) =>
orders.foreach { order =>
if (!RowOrdering.isOrderable(order.dataType)) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index c3db4787eca..4e66c87f361 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -198,6 +198,21 @@ abstract class TypeCoercionBase {
}
}
+ /**
+ * Widens the data types of the [[Unpivot]] values.
+ */
+ object UnpivotCoercion extends Rule[LogicalPlan] {
+ override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+ case up: Unpivot
+ if up.values.nonEmpty && up.values.forall(_.resolved) && !up.valuesTypeCoercioned =>
+ val valueDataType = findWiderTypeWithoutStringPromotion(up.values.map(_.dataType))
+ val values = valueDataType.map(valueType =>
+ up.values.map(value => Alias(Cast(value, valueType), value.name)())
+ ).getOrElse(up.values)
+ up.copy(values = values)
+ }
+ }
+
/**
* Widens the data types of the children of Union/Except/Intersect.
* 1. When ANSI mode is off:
@@ -806,6 +821,7 @@ abstract class TypeCoercionBase {
object TypeCoercion extends TypeCoercionBase {
override def typeCoercionRules: List[Rule[LogicalPlan]] =
+ UnpivotCoercion ::
WidenSetOperationTypes ::
new CombinedTypeCoercionRule(
InConversion ::
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index bdc7bf9bd7d..22134a06288 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -1354,6 +1354,45 @@ case class Pivot(
override protected def withNewChildInternal(newChild: LogicalPlan): Pivot =
copy(child = newChild)
}
+/**
+ * A constructor for creating an Unpivot, which will later be converted to an [[Expand]]
+ * during the query analysis.
+ *
+ * An empty values array will be replaced during analysis with all resolved outputs of child except
+ * the ids. This expansion makes it easy to unpivot all non-id columns.
+ *
+ * @see `org.apache.spark.sql.catalyst.analysis.Analyzer.ResolveUnpivot`
+ *
+ * The type of the value column is derived from all value columns during analysis once all values
+ * are resolved. All values' types have to be compatible, otherwise the result value column cannot
+ * be assigned the individual values and an AnalysisException is thrown.
+ *
+ * @see `org.apache.spark.sql.catalyst.analysis.TypeCoercionBase.UnpivotCoercion`
+ *
+ * @param ids Id columns
+ * @param values Value columns to unpivot
+ * @param variableColumnName Name of the variable column
+ * @param valueColumnName Name of the value column
+ * @param child Child operator
+ */
+case class Unpivot(
+ ids: Seq[NamedExpression],
+ values: Seq[NamedExpression],
+ variableColumnName: String,
+ valueColumnName: String,
+ child: LogicalPlan) extends UnaryNode {
+ override lazy val resolved = false // Unpivot will be replaced after being resolved.
+ override def output: Seq[Attribute] = Nil
+ override def metadataOutput: Seq[Attribute] = Nil
+ final override val nodePatterns: Seq[TreePattern] = Seq(UNPIVOT)
+
+ override protected def withNewChildInternal(newChild: LogicalPlan): Unpivot =
+ copy(child = newChild)
+
+ def valuesTypeCoercioned: Boolean = values.nonEmpty && values.forall(_.resolved) &&
+ values.tail.forall(v => v.dataType.sameType(values.head.dataType))
+}
+
/**
* A constructor for creating a logical limit, which is split into two separate logical nodes:
* a [[LocalLimit]], which is a partition local limit, followed by a [[GlobalLimit]].
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
index 2f118db8248..eda6ff60e61 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
@@ -71,6 +71,7 @@ object RuleIdCollection {
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveSubqueryColumnAliases"
::
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveTables" ::
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveTempViews" ::
+ "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveUnpivot" ::
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveUpCast" ::
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveUserSpecifiedColumns" ::
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveWindowFrame" ::
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index 93273b5a2c7..3342f11a0fa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -87,6 +87,7 @@ object TreePattern extends Enumeration {
val TRUE_OR_FALSE_LITERAL: Value = Value
val WINDOW_EXPRESSION: Value = Value
val UNARY_POSITIVE: Value = Value
+ val UNPIVOT: Value = Value
val UPDATE_FIELDS: Value = Value
val UPPER_OR_LOWER: Value = Value
val UP_CAST: Value = Value
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index c828318f2cd..c344c64997f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -92,6 +92,24 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase {
pivotVal.toString, pivotVal.dataType.simpleString, pivotCol.dataType.catalogString))
}
+ def unpivotRequiresValueColumns(): Throwable = {
+ new AnalysisException(
+ errorClass = "UNPIVOT_REQUIRES_VALUE_COLUMNS",
+ messageParameters = Array.empty)
+ }
+
+ def unpivotValDataTypeMismatchError(values: Seq[NamedExpression]): Throwable = {
+ val dataTypes = values
+ .groupBy(_.dataType)
+ .mapValues(values => values.map(value => toSQLId(value.toString)))
+ .mapValues(values => if (values.length > 3) values.take(3) :+ "..." else
values)
+ .map { case (dataType, values) => s"${toSQLType(dataType)} (${values.mkString(", ")})" }
+
+ new AnalysisException(
+ errorClass = "UNPIVOT_VALUE_DATA_TYPE_MISMATCH",
+ messageParameters = Array(dataTypes.mkString(", ")))
+ }
+
def unsupportedIfNotExistsError(tableName: String): Throwable = {
new AnalysisException(
errorClass = "UNSUPPORTED_FEATURE",
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index bc0b37e5923..49b4a8389f9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1065,7 +1065,7 @@ class Dataset[T] private[sql](
* @param joinType Type of join to perform. Default `inner`. Must be one of:
* `inner`, `cross`, `outer`, `full`, `fullouter`, `full_outer`, `left`,
* `leftouter`, `left_outer`, `right`, `rightouter`, `right_outer`,
- * `semi`, `leftsemi`, `left_semi`, `anti`, `leftanti`, left_anti`.
+ * `semi`, `leftsemi`, `left_semi`, `anti`, `leftanti`, `left_anti`.
*
* @note If you perform a self-join using this function without aliasing the input
* `DataFrame`s, you will NOT be able to reference any columns after the join, since
@@ -2036,6 +2036,142 @@ class Dataset[T] private[sql](
@scala.annotation.varargs
def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs : _*)
+ /**
+ * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns set.
+ * This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation,
+ * which cannot be reversed.
+ *
+ * This function is useful to massage a DataFrame into a format where some
+ * columns are identifier columns ("ids"), while all other columns ("values")
+ * are "unpivoted" to the rows, leaving just two non-id columns, named as
given
+ * by `variableColumnName` and `valueColumnName`.
+ *
+ * {{{
+ * val df = Seq((1, 11, 12L), (2, 21, 22L)).toDF("id", "int", "long")
+ * df.show()
+ * // output:
+ * // +---+---+----+
+ * // | id|int|long|
+ * // +---+---+----+
+ * // | 1| 11| 12|
+ * // | 2| 21| 22|
+ * // +---+---+----+
+ *
+ * df.unpivot(Array($"id"), Array($"int", $"long"), "variable",
"value").show()
+ * // output:
+ * // +---+--------+-----+
+ * // | id|variable|value|
+ * // +---+--------+-----+
+ * // | 1| int| 11|
+ * // | 1| long| 12|
+ * // | 2| int| 21|
+ * // | 2| long| 22|
+ * // +---+--------+-----+
+ * // schema:
+ * //root
+ * // |-- id: integer (nullable = false)
+ * // |-- variable: string (nullable = false)
+ * // |-- value: long (nullable = true)
+ * }}}
+ *
+ * When no "id" columns are given, the unpivoted DataFrame consists of only
the
+ * "variable" and "value" columns.
+ *
+ * All "value" columns must share a least common data type. Unless they are
the same data type,
+ * all "value" columns are cast to the nearest common data type. For
instance,
+ * types `IntegerType` and `LongType` are cast to `LongType`, while
`IntegerType` and `StringType`
+ * do not have a common data type and `unpivot` fails.
+ *
+ * @param ids Id columns
+ * @param values Value columns to unpivot
+ * @param variableColumnName Name of the variable column
+ * @param valueColumnName Name of the value column
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ def unpivot(
+ ids: Array[Column],
+ values: Array[Column],
+ variableColumnName: String,
+ valueColumnName: String): DataFrame = withPlan {
+ Unpivot(
+ ids.map(_.named),
+ values.map(_.named),
+ variableColumnName,
+ valueColumnName,
+ logicalPlan
+ )
+ }
+
+ /**
+ * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns set.
+ * This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation,
+ * which cannot be reversed.
+ *
+ * @see `org.apache.spark.sql.Dataset.unpivot(Array, Array, String, String)`
+ *
+ * This is equivalent to calling `Dataset#unpivot(Array, Array, String, String)`
+ * where `values` is set to all non-id columns that exist in the DataFrame.
+ *
+ * @param ids Id columns
+ * @param variableColumnName Name of the variable column
+ * @param valueColumnName Name of the value column
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ def unpivot(
+ ids: Array[Column],
+ variableColumnName: String,
+ valueColumnName: String): DataFrame =
+ unpivot(ids, Array.empty, variableColumnName, valueColumnName)
+
+ /**
+ * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns set.
+ * This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation,
+ * which cannot be reversed. This is an alias for `unpivot`.
+ *
+ * @see `org.apache.spark.sql.Dataset.unpivot(Array, Array, String, String)`
+ *
+ * @param ids Id columns
+ * @param values Value columns to unpivot
+ * @param variableColumnName Name of the variable column
+ * @param valueColumnName Name of the value column
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ def melt(
+ ids: Array[Column],
+ values: Array[Column],
+ variableColumnName: String,
+ valueColumnName: String): DataFrame =
+ unpivot(ids, values, variableColumnName, valueColumnName)
+
+ /**
+ * Unpivot a DataFrame from wide format to long format, optionally leaving identifier columns set.
+ * This is the reverse to `groupBy(...).pivot(...).agg(...)`, except for the aggregation,
+ * which cannot be reversed. This is an alias for `unpivot`.
+ *
+ * @see `org.apache.spark.sql.Dataset.unpivot(Array, Array, String, String)`
+ *
+ * This is equivalent to calling `Dataset#unpivot(Array, Array, String, String)`
+ * where `values` is set to all non-id columns that exist in the DataFrame.
+ *
+ * @param ids Id columns
+ * @param variableColumnName Name of the variable column
+ * @param valueColumnName Name of the value column
+ *
+ * @group untypedrel
+ * @since 3.4.0
+ */
+ def melt(
+ ids: Array[Column],
+ variableColumnName: String,
+ valueColumnName: String): DataFrame =
+ unpivot(ids, variableColumnName, valueColumnName)
+
/**
* Define (named) metrics to observe on the Dataset. This method returns an 'observed' Dataset
* that returns the same result as the input, with the following guarantees:
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 7e3c6221961..989ee325218 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -343,6 +343,9 @@ class RelationalGroupedDataset protected[sql](
* df.groupBy("year").pivot("course").sum("earnings")
* }}}
*
+ * @see `org.apache.spark.sql.Dataset.unpivot` for the reverse operation,
+ * except for the aggregation.
+ *
* @param pivotColumn Name of the column to pivot.
* @since 1.6.0
*/
@@ -371,6 +374,9 @@ class RelationalGroupedDataset protected[sql](
* .agg(sum($"earnings"))
* }}}
*
+ * @see `org.apache.spark.sql.Dataset.unpivot` for the reverse operation,
+ * except for the aggregation.
+ *
* @param pivotColumn Name of the column to pivot.
* @param values List of values that will be translated to columns in the output DataFrame.
* @since 1.6.0
@@ -395,6 +401,9 @@ class RelationalGroupedDataset protected[sql](
* df.groupBy("year").pivot("course").sum("earnings");
* }}}
*
+ * @see `org.apache.spark.sql.Dataset.unpivot` for the reverse operation,
+ * except for the aggregation.
+ *
* @param pivotColumn Name of the column to pivot.
* @param values List of values that will be translated to columns in the output DataFrame.
* @since 1.6.0
@@ -412,6 +421,9 @@ class RelationalGroupedDataset protected[sql](
* df.groupBy($"year").pivot($"course").sum($"earnings");
* }}}
*
+ * @see `org.apache.spark.sql.Dataset.unpivot` for the reverse operation,
+ * except for the aggregation.
+ *
* @param pivotColumn the column to pivot.
* @since 2.4.0
*/
@@ -444,6 +456,9 @@ class RelationalGroupedDataset protected[sql](
* df.groupBy($"year").pivot($"course", Seq("dotNET",
"Java")).sum($"earnings")
* }}}
*
+ * @see `org.apache.spark.sql.Dataset.unpivot` for the reverse operation,
+ * except for the aggregation.
+ *
* @param pivotColumn the column to pivot.
* @param values List of values that will be translated to columns in the output DataFrame.
* @since 2.4.0
@@ -477,6 +492,9 @@ class RelationalGroupedDataset protected[sql](
* aggregation. This is an overloaded version of the `pivot` method with `pivotColumn` of
* the `String` type.
*
+ * @see `org.apache.spark.sql.Dataset.unpivot` for the reverse operation,
+ * except for the aggregation.
+ *
* @param pivotColumn the column to pivot.
* @param values List of values that will be translated to columns in the output DataFrame.
* @since 2.4.0
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetUnpivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetUnpivotSuite.scala
new file mode 100644
index 00000000000..8ccad457e8d
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetUnpivotSuite.scala
@@ -0,0 +1,543 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.errors.QueryErrorsSuiteBase
+import org.apache.spark.sql.functions.{length, struct, sum}
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types._
+
+/**
+ * Comprehensive tests for Dataset.unpivot.
+ */
+class DatasetUnpivotSuite extends QueryTest
+ with QueryErrorsSuiteBase
+ with SharedSparkSession {
+ import testImplicits._
+
+ lazy val wideDataDs: Dataset[WideData] = Seq(
+ WideData(1, "one", "One", Some(1), Some(1L)),
+ WideData(2, "two", null, None, Some(2L)),
+ WideData(3, null, "three", Some(3), None),
+ WideData(4, null, null, None, None)
+ ).toDS()
+
+ val longDataRows = Seq(
+ Row(1, "str1", "one"),
+ Row(1, "str2", "One"),
+ Row(2, "str1", "two"),
+ Row(2, "str2", null),
+ Row(3, "str1", null),
+ Row(3, "str2", "three"),
+ Row(4, "str1", null),
+ Row(4, "str2", null)
+ )
+
+ val longDataWithoutIdRows: Seq[Row] =
+ longDataRows.map(row => Row(row.getString(1), row.getString(2)))
+
+ val longSchema: StructType = StructType(Seq(
+ StructField("id", IntegerType, nullable = false),
+ StructField("var", StringType, nullable = false),
+ StructField("val", StringType, nullable = true)
+ ))
+
+ lazy val wideStructDataDs: DataFrame = wideDataDs.select(
+ struct($"id").as("an"),
+ struct(
+ $"str1".as("one"),
+ $"str2".as("two")
+ ).as("str")
+ )
+ val longStructDataRows: Seq[Row] = longDataRows.map(row =>
+ Row(
+ row.getInt(0),
+ row.getString(1) match {
+ case "str1" => "one"
+ case "str2" => "two"
+ },
+ row.getString(2))
+ )
+
+ test("overloaded unpivot without values") {
+ val ds = wideDataDs.select($"id", $"str1", $"str2")
+ checkAnswer(
+ ds.unpivot(Array($"id"), "var", "val"),
+ ds.unpivot(Array($"id"), Array.empty, "var", "val"))
+ }
+
+ test("unpivot with single id") {
+ val unpivoted = wideDataDs
+ .unpivot(
+ Array($"id"),
+ Array($"str1", $"str2"),
+ variableColumnName = "var",
+ valueColumnName = "val")
+ assert(unpivoted.schema === longSchema)
+ checkAnswer(unpivoted, longDataRows)
+ }
+
+ test("unpivot with two ids") {
+ val unpivotedRows = Seq(
+ Row(1, 1, "str1", "one"),
+ Row(1, 1, "str2", "One"),
+ Row(2, null, "str1", "two"),
+ Row(2, null, "str2", null),
+ Row(3, 3, "str1", null),
+ Row(3, 3, "str2", "three"),
+ Row(4, null, "str1", null),
+ Row(4, null, "str2", null))
+
+ val unpivoted = wideDataDs
+ .unpivot(
+ Array($"id", $"int1"),
+ Array($"str1", $"str2"),
+ variableColumnName = "var",
+ valueColumnName = "val")
+ assert(unpivoted.schema === StructType(Seq(
+ StructField("id", IntegerType, nullable = false),
+ StructField("int1", IntegerType, nullable = true),
+ StructField("var", StringType, nullable = false),
+ StructField("val", StringType, nullable = true))))
+ checkAnswer(unpivoted, unpivotedRows)
+ }
+
+ test("unpivot without ids") {
+ val unpivoted = wideDataDs
+ .unpivot(
+ Array.empty,
+ Array($"str1", $"str2"),
+ variableColumnName = "var",
+ valueColumnName = "val")
+ assert(unpivoted.schema === StructType(Seq(
+ StructField("var", StringType, nullable = false),
+ StructField("val", StringType, nullable = true))))
+ checkAnswer(unpivoted, longDataWithoutIdRows)
+ }
+
+ test("unpivot without values") {
+ val unpivoted = wideDataDs.select($"id", $"str1", $"str2")
+ .unpivot(
+ Array($"id"),
+ variableColumnName = "var",
+ valueColumnName = "val")
+ assert(unpivoted.schema === longSchema)
+ checkAnswer(unpivoted, longDataRows)
+
+ val unpivoted2 = wideDataDs.select($"id", $"str1", $"str2")
+ .unpivot(
+ Array($"id"),
+ Array.empty,
+ variableColumnName = "var",
+ valueColumnName = "val")
+ assert(unpivoted2.schema === longSchema)
+ checkAnswer(unpivoted2, longDataRows)
+
+ val unpivotedRows = Seq(
+ Row(1, "id", 1L),
+ Row(1, "int1", 1L),
+ Row(1, "long1", 1L),
+ Row(2, "id", 2L),
+ Row(2, "int1", null),
+ Row(2, "long1", 2L),
+ Row(3, "id", 3L),
+ Row(3, "int1", 3L),
+ Row(3, "long1", null),
+ Row(4, "id", 4L),
+ Row(4, "int1", null),
+ Row(4, "long1", null)
+ )
+
+ val unpivoted3 = wideDataDs.select($"id", $"int1", $"long1")
+ .unpivot(
+ Array($"id" * 2),
+ Array.empty,
+ variableColumnName = "var",
+ valueColumnName = "val")
+ assert(unpivoted3.schema === StructType(Seq(
+ StructField("(id * 2)", IntegerType, nullable = false),
+ StructField("var", StringType, nullable = false),
+ StructField("val", LongType, nullable = true)
+ )))
+ checkAnswer(unpivoted3, unpivotedRows.map(row =>
+ Row(row.getInt(0) * 2, row.get(1), row.get(2))))
+
+ val unpivoted4 = wideDataDs.select($"id", $"int1", $"long1")
+ .unpivot(
+ Array($"id".as("uid")),
+ Array.empty,
+ variableColumnName = "var",
+ valueColumnName = "val")
+ assert(unpivoted4.schema === StructType(Seq(
+ StructField("uid", IntegerType, nullable = false),
+ StructField("var", StringType, nullable = false),
+ StructField("val", LongType, nullable = true)
+ )))
+ checkAnswer(unpivoted4, unpivotedRows)
+ }
+
+ test("unpivot without ids or values") {
+ val unpivoted = wideDataDs.select($"str1", $"str2")
+ .unpivot(
+ Array.empty,
+ Array.empty,
+ variableColumnName = "var",
+ valueColumnName = "val")
+ assert(unpivoted.schema === StructType(Seq(
+ StructField("var", StringType, nullable = false),
+ StructField("val", StringType, nullable = true))))
+ checkAnswer(unpivoted, longDataWithoutIdRows)
+ }
+
+ test("unpivot with star values") {
+ val unpivoted = wideDataDs.select($"str1", $"str2")
+ .unpivot(
+ Array.empty,
+ Array($"*"),
+ variableColumnName = "var",
+ valueColumnName = "val")
+ assert(unpivoted.schema === StructType(Seq(
+ StructField("var", StringType, nullable = false),
+ StructField("val", StringType, nullable = true))))
+ checkAnswer(unpivoted, longDataWithoutIdRows)
+ }
+
+ test("unpivot with id and star values") {
+ val unpivoted = wideDataDs.select($"id", $"int1", $"long1")
+ .unpivot(
+ Array($"id"),
+ Array($"*"),
+ variableColumnName = "var",
+ valueColumnName = "val")
+
+ assert(unpivoted.schema === StructType(Seq(
+ StructField("id", IntegerType, nullable = false),
+ StructField("var", StringType, nullable = false),
+ StructField("val", LongType, nullable = true))))
+
+ checkAnswer(unpivoted, wideDataDs.collect().flatMap { row => Seq(
+ Row(row.id, "id", row.id),
+ Row(row.id, "int1", row.int1.orNull),
+ Row(row.id, "long1", row.long1.orNull)
+ )})
+ }
+
+ test("unpivot with expressions") {
+ // ids and values are all expressions (computed)
+ val unpivoted = wideDataDs
+ .unpivot(
+ Array(($"id" * 10).as("primary"), $"str1".as("secondary")),
+ Array(($"int1" + $"long1").as("sum"), length($"str2").as("len")),
+ variableColumnName = "var",
+ valueColumnName = "val")
+
+ assert(unpivoted.schema === StructType(Seq(
+ StructField("primary", IntegerType, nullable = false),
+ StructField("secondary", StringType, nullable = true),
+ StructField("var", StringType, nullable = false),
+ StructField("val", LongType, nullable = true))))
+
+ checkAnswer(unpivoted, wideDataDs.collect().flatMap { row =>
+ Seq(
+ Row(
+ row.id * 10,
+ row.str1,
+ "sum",
+ // sum of int1 and long1 when both are set, or null otherwise
+ row.int1.flatMap(i => row.long1.map(l => i + l)).orNull),
+ Row(
+ row.id * 10,
+ row.str1,
+ "len",
+ // length of str2 if set, or null otherwise
+ Option(row.str2).map(_.length).orNull)
+ )
+ })
+ }
+
+ test("unpivot with variable / value columns") {
+ // with value column `variable` and `value`
+ val unpivoted = wideDataDs
+ .withColumnRenamed("str1", "var")
+ .withColumnRenamed("str2", "val")
+ .unpivot(
+ Array($"id"),
+ Array($"var", $"val"),
+ variableColumnName = "var",
+ valueColumnName = "val")
+ checkAnswer(unpivoted, longDataRows.map(row => Row(
+ row.getInt(0),
+ row.getString(1) match {
+ case "str1" => "var"
+ case "str2" => "val"
+ },
+ row.getString(2))))
+ }
+
+ test("unpivot with incompatible value types") {
+ val e = intercept[AnalysisException] {
+ wideDataDs
+ .select(
+ $"id",
+ $"str1",
+ $"int1", $"int1".as("int2"), $"int1".as("int3"), $"int1".as("int4"),
+ $"long1", $"long1".as("long2")
+ )
+ .unpivot(
+ Array($"id"),
+ Array(),
+ variableColumnName = "var",
+ valueColumnName = "val"
+ )
+ }
+ checkErrorClass(
+ exception = e,
+ errorClass = "UNPIVOT_VALUE_DATA_TYPE_MISMATCH",
+ msg = "Unpivot value columns must share a least common type, some types
do not: \\[" +
+ "\"STRING\" \\(`str1#\\d+`\\), " +
+ "\"INT\" \\(`int1#\\d+`, `int2#\\d+`, `int3#\\d+`, ...\\), " +
+ "\"BIGINT\" \\(`long1#\\d+L`, `long2#\\d+L`\\)\\];(\n.*)*",
+ matchMsg = true)
+ }
+
+ test("unpivot with compatible value types") {
+ val unpivoted = wideDataDs.unpivot(
+ Array($"id"),
+ Array($"int1", $"long1"),
+ variableColumnName = "var",
+ valueColumnName = "val")
+ assert(unpivoted.schema === StructType(Seq(
+ StructField("id", IntegerType, nullable = false),
+ StructField("var", StringType, nullable = false),
+ StructField("val", LongType, nullable = true)
+ )))
+
+ val unpivotedRows = Seq(
+ Row(1, "int1", 1L),
+ Row(1, "long1", 1L),
+ Row(2, "int1", null),
+ Row(2, "long1", 2L),
+ Row(3, "int1", 3L),
+ Row(3, "long1", null),
+ Row(4, "int1", null),
+ Row(4, "long1", null)
+ )
+ checkAnswer(unpivoted, unpivotedRows)
+ }
+
+ test("unpivot and drop nulls") {
+ checkAnswer(
+ wideDataDs
+ .unpivot(Array($"id"), Array($"str1", $"str2"), "var", "val")
+ .where($"val".isNotNull),
+ longDataRows.filter(_.getString(2) != null))
+ }
+
+ test("unpivot with invalid arguments") {
+ // unpivoting where id column does not exist
+ val e1 = intercept[AnalysisException] {
+ wideDataDs.unpivot(
+ Array($"1", $"2"),
+ Array($"str1", $"str2"),
+ variableColumnName = "var",
+ valueColumnName = "val"
+ )
+ }
+ checkErrorClass(
+ exception = e1,
+ errorClass = "UNRESOLVED_COLUMN",
+ msg = "A column or function parameter with name `1` cannot be
resolved\\. " +
+ "Did you mean one of the following\\? \\[`id`, `int1`, `str1`, `str2`,
`long1`\\];(\n.*)*",
+ matchMsg = true)
+
+ // unpivoting where value column does not exist
+ val e2 = intercept[AnalysisException] {
+ wideDataDs.unpivot(
+ Array($"id"),
+ Array($"does", $"not", $"exist"),
+ variableColumnName = "var",
+ valueColumnName = "val"
+ )
+ }
+ checkErrorClass(
+ exception = e2,
+ errorClass = "UNRESOLVED_COLUMN",
+ msg = "A column or function parameter with name `does` cannot be
resolved\\. " +
+ "Did you mean one of the following\\? \\[`id`, `int1`, `long1`,
`str1`, `str2`\\];(\n.*)*",
+ matchMsg = true)
+
+ // unpivoting with empty list of value columns
+ // where potential value columns are of incompatible types
+ val e3 = intercept[AnalysisException] {
+ wideDataDs.unpivot(
+ Array.empty,
+ Array.empty,
+ variableColumnName = "var",
+ valueColumnName = "val"
+ )
+ }
+ checkErrorClass(
+ exception = e3,
+ errorClass = "UNPIVOT_VALUE_DATA_TYPE_MISMATCH",
+ msg = "Unpivot value columns must share a least common type, some types
do not: \\[" +
+ "\"INT\" \\(`id#\\d+`, `int1#\\d+`\\), " +
+ "\"STRING\" \\(`str1#\\d+`, `str2#\\d+`\\), " +
+ "\"BIGINT\" \\(`long1#\\d+L`\\)\\];(\n.*)*",
+ matchMsg = true)
+
+ // unpivoting with star id columns so that no value columns are left
+ val e4 = intercept[AnalysisException] {
+ wideDataDs.unpivot(
+ Array($"*"),
+ Array.empty,
+ variableColumnName = "var",
+ valueColumnName = "val"
+ )
+ }
+ checkErrorClass(
+ exception = e4,
+ errorClass = "UNPIVOT_REQUIRES_VALUE_COLUMNS",
+ msg = "At least one value column needs to be specified for UNPIVOT, " +
+ "all columns specified as ids;(\\n.*)*",
+ matchMsg = true)
+
+ // unpivoting with star value columns
+ // where potential value columns are of incompatible types
+ val e5 = intercept[AnalysisException] {
+ wideDataDs.unpivot(
+ Array.empty,
+ Array($"*"),
+ variableColumnName = "var",
+ valueColumnName = "val"
+ )
+ }
+ checkErrorClass(
+ exception = e5,
+ errorClass = "UNPIVOT_VALUE_DATA_TYPE_MISMATCH",
+ msg = "Unpivot value columns must share a least common type, some types
do not: \\[" +
+ "\"INT\" \\(`id#\\d+`, `int1#\\d+`\\), " +
+ "\"STRING\" \\(`str1#\\d+`, `str2#\\d+`\\), " +
+ "\"BIGINT\" \\(`long1#\\d+L`\\)\\];(\n.*)*",
+ matchMsg = true)
+
+ // unpivoting without giving values and no non-id columns
+ val e6 = intercept[AnalysisException] {
+ wideDataDs.select($"id", $"str1", $"str2").unpivot(
+ Array($"id", $"str1", $"str2"),
+ Array.empty,
+ variableColumnName = "var",
+ valueColumnName = "val"
+ )
+ }
+ checkErrorClass(
+ exception = e6,
+ errorClass = "UNPIVOT_REQUIRES_VALUE_COLUMNS",
+ msg = "At least one value column needs to be specified for UNPIVOT, " +
+ "all columns specified as ids;(\\n.*)*",
+ matchMsg = true)
+ }
+
+ test("unpivot after pivot") {
+ // see test "pivot courses" in DataFramePivotSuite
+ val pivoted = courseSales.groupBy("year").pivot("course", Array("dotNET",
"Java"))
+ .agg(sum($"earnings"))
+ val unpivoted = pivoted.unpivot(Array($"year"), "course", "earnings")
+ val expected = courseSales.groupBy("year", "course").sum("earnings")
+ checkAnswer(unpivoted, expected)
+ }
+
+ test("unpivot of unpivot") {
+ checkAnswer(
+ wideDataDs
+ .unpivot(Array($"id"), Array($"str1", $"str2"), "var", "val")
+ .unpivot(Array($"id"), Array($"var", $"val"), "col", "value"),
+ longDataRows.flatMap(row => Seq(
+ Row(row.getInt(0), "var", row.getString(1)),
+ Row(row.getInt(0), "val", row.getString(2)))))
+ }
+
+ test("unpivot with dot and backtick") {
+ val ds = wideDataDs
+ .withColumnRenamed("id", "an.id")
+ .withColumnRenamed("str1", "str.one")
+ .withColumnRenamed("str2", "str.two")
+
+ val unpivoted = ds.unpivot(
+ Array($"`an.id`"),
+ Array($"`str.one`", $"`str.two`"),
+ variableColumnName = "var",
+ valueColumnName = "val")
+ checkAnswer(unpivoted, longDataRows.map(row => Row(
+ row.getInt(0),
+ row.getString(1) match {
+ case "str1" => "str.one"
+ case "str2" => "str.two"
+ },
+ row.getString(2))))
+
+ // without backticks, this references struct fields, which do not exist
+ val e = intercept[AnalysisException] {
+ ds.unpivot(
+ Array($"an.id"),
+ Array($"str.one", $"str.two"),
+ variableColumnName = "var",
+ valueColumnName = "val"
+ )
+ }
+ checkErrorClass(
+ exception = e,
+ errorClass = "UNRESOLVED_COLUMN",
+ // expected message is wrong: https://issues.apache.org/jira/browse/SPARK-39783
+ msg = "A column or function parameter with name `an`\\.`id` cannot be
resolved\\. " +
+ "Did you mean one of the following\\? " +
+ "\\[`an`.`id`, `int1`, `long1`, `str`.`one`, `str`.`two`\\];(\n.*)*",
+ matchMsg = true)
+ }
+
+ test("unpivot with struct fields") {
+ checkAnswer(
+ wideStructDataDs.unpivot(
+ Array($"an.id"),
+ Array($"str.one", $"str.two"),
+ "var",
+ "val"),
+ longStructDataRows)
+ }
+
+ test("unpivot with struct ids star") {
+ checkAnswer(
+ wideStructDataDs.unpivot(
+ Array($"an.*"),
+ Array($"str.one", $"str.two"),
+ "var",
+ "val"),
+ longStructDataRows)
+ }
+
+ test("unpivot with struct values star") {
+ checkAnswer(
+ wideStructDataDs.unpivot(
+ Array($"an.id"),
+ Array($"str.*"),
+ "var",
+ "val"),
+ longStructDataRows)
+ }
+}
+
+case class WideData(id: Int, str1: String, str2: String, int1: Option[Int], long1: Option[Long])
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryErrorsSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryErrorsSuiteBase.scala
index 895a72efeec..d78a6a91959 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryErrorsSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryErrorsSuiteBase.scala
@@ -37,7 +37,8 @@ trait QueryErrorsSuiteBase extends SharedSparkSession {
errorClass
}
if (matchMsg) {
- assert(exception.getMessage.matches(s"""\\[$fullErrorClass\\] """ + msg))
+ assert(exception.getMessage.matches(s"""\\[$fullErrorClass\\] """ + msg),
+ "exception is: " + exception.getMessage)
} else {
assert(exception.getMessage === s"""[$fullErrorClass] """ + msg)
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]