Repository: spark
Updated Branches:
refs/heads/master a3afa4a1b -> 5c78be7a5
[SPARK-5799][SQL] Compute aggregation function on specified numeric columns
Compute aggregation functions on specified numeric columns. For example:

val df = Seq(("a", 1, 0, "b"), ("b", 2, 4, "c"), ("a", 2, 3, "d"))
  .toDF("key", "value1", "value2", "rest")
df.groupBy("key").min("value2")
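With column names given, only those columns are aggregated; the no-argument form still aggregates every numeric column. A minimal sketch of the expected behavior, assuming sqlContext implicits are in scope for toDF (results taken from the new DataFrameSuite test below; exact Row formatting may vary):

df.groupBy("key").min("value2").collect()
// Row("a", 0) and Row("b", 4) -- only "value2" is aggregated

df.groupBy("key").min().collect()
// aggregates both numeric columns, "value1" and "value2"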
Author: Liang-Chi Hsieh <[email protected]>
Closes #4592 from viirya/specific_cols_agg and squashes the following commits:
9446896 [Liang-Chi Hsieh] For comments.
314c4cd [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into specific_cols_agg
353fad7 [Liang-Chi Hsieh] For python unit tests.
54ed0c4 [Liang-Chi Hsieh] Address comments.
b079e6b [Liang-Chi Hsieh] Remove duplicate codes.
55100fb [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into specific_cols_agg
880c2ac [Liang-Chi Hsieh] Fix Python style checks.
4c63a01 [Liang-Chi Hsieh] Fix pyspark.
b1a24fc [Liang-Chi Hsieh] Address comments.
2592f29 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into specific_cols_agg
27069c3 [Liang-Chi Hsieh] Combine functions and add varargs annotation.
371a3f7 [Liang-Chi Hsieh] Compute aggregation function on specified numeric columns.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5c78be7a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5c78be7a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5c78be7a
Branch: refs/heads/master
Commit: 5c78be7a515fc2fc92cda0517318e7b5d85762f4
Parents: a3afa4a
Author: Liang-Chi Hsieh <[email protected]>
Authored: Mon Feb 16 10:06:11 2015 -0800
Committer: Reynold Xin <[email protected]>
Committed: Mon Feb 16 10:06:11 2015 -0800
----------------------------------------------------------------------
python/pyspark/sql/dataframe.py | 74 ++++++++++++++++----
python/pyspark/sql/functions.py | 2 +
.../org/apache/spark/sql/DataFrameImpl.scala | 4 +-
.../org/apache/spark/sql/GroupedData.scala | 57 ++++++++++++---
.../org/apache/spark/sql/DataFrameSuite.scala | 12 ++++
5 files changed, 123 insertions(+), 26 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/5c78be7a/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 1438fe5..28a59e7 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -664,6 +664,18 @@ def dfapi(f):
return _api
+def df_varargs_api(f):
+ def _api(self, *args):
+ jargs = ListConverter().convert(args, self.sql_ctx._sc._gateway._gateway_client)
+ name = f.__name__
+ jdf = getattr(self._jdf, name)(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jargs))
+ return DataFrame(jdf, self.sql_ctx)
+ _api.__name__ = f.__name__
+ _api.__doc__ = f.__doc__
+ return _api
+
+
class GroupedData(object):
"""
@@ -714,30 +726,60 @@ class GroupedData(object):
[Row(age=2, count=1), Row(age=5, count=1)]
"""
- @dfapi
- def mean(self):
+ @df_varargs_api
+ def mean(self, *cols):
"""Compute the average value for each numeric columns
- for each group. This is an alias for `avg`."""
+ for each group. This is an alias for `avg`.
- @dfapi
- def avg(self):
+ >>> df.groupBy().mean('age').collect()
+ [Row(AVG(age#0)=3.5)]
+ >>> df3.groupBy().mean('age', 'height').collect()
+ [Row(AVG(age#4)=3.5, AVG(height#5)=82.5)]
+ """
+
+ @df_varargs_api
+ def avg(self, *cols):
"""Compute the average value for each numeric columns
- for each group."""
+ for each group.
- @dfapi
- def max(self):
+ >>> df.groupBy().avg('age').collect()
+ [Row(AVG(age#0)=3.5)]
+ >>> df3.groupBy().avg('age', 'height').collect()
+ [Row(AVG(age#4)=3.5, AVG(height#5)=82.5)]
+ """
+
+ @df_varargs_api
+ def max(self, *cols):
"""Compute the max value for each numeric columns for
- each group. """
+ each group.
- @dfapi
- def min(self):
+ >>> df.groupBy().max('age').collect()
+ [Row(MAX(age#0)=5)]
+ >>> df3.groupBy().max('age', 'height').collect()
+ [Row(MAX(age#4)=5, MAX(height#5)=85)]
+ """
+
+ @df_varargs_api
+ def min(self, *cols):
"""Compute the min value for each numeric column for
- each group."""
+ each group.
- @dfapi
- def sum(self):
+ >>> df.groupBy().min('age').collect()
+ [Row(MIN(age#0)=2)]
+ >>> df3.groupBy().min('age', 'height').collect()
+ [Row(MIN(age#4)=2, MIN(height#5)=80)]
+ """
+
+ @df_varargs_api
+ def sum(self, *cols):
"""Compute the sum for each numeric columns for each
- group."""
+ group.
+
+ >>> df.groupBy().sum('age').collect()
+ [Row(SUM(age#0)=7)]
+ >>> df3.groupBy().sum('age', 'height').collect()
+ [Row(SUM(age#4)=7, SUM(height#5)=165)]
+ """
def _create_column_from_literal(literal):
@@ -945,6 +987,8 @@ def _test():
globs['sqlCtx'] = SQLContext(sc)
globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF()
globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF()
+ globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80),
+ Row(name='Bob', age=5, height=85)]).toDF()
(failure_count, test_count) = doctest.testmod(
pyspark.sql.dataframe, globs=globs,
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
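The added doctests exercise the new multi-column aggregation from PySpark through the df_varargs_api decorator; the same calls exist on the Scala GroupedData API. A rough Scala counterpart of the df3 fixture (illustrative names only; assumes a SQLContext with implicits imported):

case class Person(name: String, age: Int, height: Int)
val df3 = Seq(Person("Alice", 2, 80), Person("Bob", 5, 85)).toDF()

df3.groupBy().mean("age", "height").collect()
// one row with AVG(age) = 3.5 and AVG(height) = 82.5, matching the doctest above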
http://git-wip-us.apache.org/repos/asf/spark/blob/5c78be7a/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 39aa550..d0e0906 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -158,6 +158,8 @@ def _test():
globs['sqlCtx'] = SQLContext(sc)
globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF()
globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF()
+ globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80),
+ Row(name='Bob', age=5, height=85)]).toDF()
(failure_count, test_count) = doctest.testmod(
pyspark.sql.dataframe, globs=globs,
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
http://git-wip-us.apache.org/repos/asf/spark/blob/5c78be7a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
index 7b7efbe..9eb0c13 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
@@ -88,12 +88,12 @@ private[sql] class DataFrameImpl protected[sql](
}
}
- protected[sql] def numericColumns: Seq[Expression] = {
+ protected[sql] def numericColumns(): Seq[Expression] = {
schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n =>
queryExecution.analyzed.resolve(n.name, sqlContext.analyzer.resolver).get
}
}
-
+
override def toDF(colNames: String*): DataFrame = {
require(schema.size == colNames.size,
"The number of columns doesn't match.\n" +
http://git-wip-us.apache.org/repos/asf/spark/blob/5c78be7a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index 0868013..a5a677b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -23,6 +23,8 @@ import scala.collection.JavaConversions._
import org.apache.spark.sql.catalyst.analysis.Star
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
+import org.apache.spark.sql.types.NumericType
+
/**
@@ -39,13 +41,30 @@ class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expressio
df.sqlContext, Aggregate(groupingExprs, namedGroupingExprs ++ aggExprs, df.logicalPlan))
}
- private[this] def aggregateNumericColumns(f: Expression => Expression): Seq[NamedExpression] = {
- df.numericColumns.map { c =>
+ private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => Expression)
+ : Seq[NamedExpression] = {
+
+ val columnExprs = if (colNames.isEmpty) {
+ // No columns specified. Use all numeric columns.
+ df.numericColumns
+ } else {
+ // Make sure all specified columns are numeric
+ colNames.map { colName =>
+ val namedExpr = df.resolve(colName)
+ if (!namedExpr.dataType.isInstanceOf[NumericType]) {
+ throw new AnalysisException(
+ s""""$colName" is not a numeric column. """ +
+ "Aggregation function can only be performed on a numeric column.")
+ }
+ namedExpr
+ }
+ }
+ columnExprs.map { c =>
val a = f(c)
Alias(a, a.toString)()
}
}
-
+
private[this] def strToExpr(expr: String): (Expression => Expression) = {
expr.toLowerCase match {
case "avg" | "average" | "mean" => Average
@@ -152,30 +171,50 @@ class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expressio
/**
* Compute the average value for each numeric columns for each group. This is an alias for `avg`.
* The resulting [[DataFrame]] will also contain the grouping columns.
+ * When specified columns are given, only compute the average values for them.
*/
- def mean(): DataFrame = aggregateNumericColumns(Average)
-
+ @scala.annotation.varargs
+ def mean(colNames: String*): DataFrame = {
+ aggregateNumericColumns(colNames:_*)(Average)
+ }
+
/**
* Compute the max value for each numeric columns for each group.
* The resulting [[DataFrame]] will also contain the grouping columns.
+ * When specified columns are given, only compute the max values for them.
*/
- def max(): DataFrame = aggregateNumericColumns(Max)
+ @scala.annotation.varargs
+ def max(colNames: String*): DataFrame = {
+ aggregateNumericColumns(colNames:_*)(Max)
+ }
/**
* Compute the mean value for each numeric columns for each group.
* The resulting [[DataFrame]] will also contain the grouping columns.
+ * When specified columns are given, only compute the mean values for them.
*/
- def avg(): DataFrame = aggregateNumericColumns(Average)
+ @scala.annotation.varargs
+ def avg(colNames: String*): DataFrame = {
+ aggregateNumericColumns(colNames:_*)(Average)
+ }
/**
* Compute the min value for each numeric column for each group.
* The resulting [[DataFrame]] will also contain the grouping columns.
+ * When specified columns are given, only compute the min values for them.
*/
- def min(): DataFrame = aggregateNumericColumns(Min)
+ @scala.annotation.varargs
+ def min(colNames: String*): DataFrame = {
+ aggregateNumericColumns(colNames:_*)(Min)
+ }
/**
* Compute the sum for each numeric columns for each group.
* The resulting [[DataFrame]] will also contain the grouping columns.
+ * When specified columns are given, only compute the sum for them.
*/
- def sum(): DataFrame = aggregateNumericColumns(Sum)
+ @scala.annotation.varargs
+ def sum(colNames: String*): DataFrame = {
+ aggregateNumericColumns(colNames:_*)(Sum)
+ }
}
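With the aggregateNumericColumns change above, naming a non-numeric column is rejected up front with an AnalysisException. A minimal sketch of the two paths (illustrative data; assumes sqlContext implicits are in scope):

val df = Seq(("a", 1, "x"), ("b", 2, "y")).toDF("key", "value", "rest")

df.groupBy("key").sum("value")   // ok: "value" is numeric
df.groupBy("key").sum("rest")    // throws AnalysisException:
                                 //   "rest" is not a numeric column. Aggregation function
                                 //   can only be performed on a numeric column.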
http://git-wip-us.apache.org/repos/asf/spark/blob/5c78be7a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index f0cd436..524571d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -162,6 +162,18 @@ class DataFrameSuite extends QueryTest {
testData2.groupBy("a").agg(Map("b" -> "sum")),
Row(1, 3) :: Row(2, 3) :: Row(3, 3) :: Nil
)
+
+ val df1 = Seq(("a", 1, 0, "b"), ("b", 2, 4, "c"), ("a", 2, 3, "d"))
+ .toDF("key", "value1", "value2", "rest")
+
+ checkAnswer(
+ df1.groupBy("key").min(),
+ df1.groupBy("key").min("value1", "value2").collect()
+ )
+ checkAnswer(
+ df1.groupBy("key").min("value2"),
+ Seq(Row("a", 0), Row("b", 4))
+ )
}
test("agg without groups") {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]