Repository: spark
Updated Branches:
refs/heads/master 065952961 -> c3576ffcd
[SQL] Minor: Introduce SchemaRDD#aggregate() for simple aggregations
```scala
rdd.aggregate(Sum('val))
```
is just shorthand for
```scala
rdd.groupBy()(Sum('val))
```
but seems to be more natural than doing a groupBy with no grouping expressions
when you really just want an aggregation over all rows.
Did not add a JavaSchemaRDD or Python API, as these seem to be lacking several
other methods like groupBy() already -- leaving that cleanup for future patches.
Author: Aaron Davidson <[email protected]>
Closes #874 from aarondav/schemardd and squashes the following commits:
e9e68ee [Aaron Davidson] Add comment
db6afe2 [Aaron Davidson] Introduce SchemaRDD#aggregate() for simple aggregations
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c3576ffc
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c3576ffc
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c3576ffc
Branch: refs/heads/master
Commit: c3576ffcd7910e38928f233a824dd9e037cde05f
Parents: 0659529
Author: Aaron Davidson <[email protected]>
Authored: Sun May 25 18:37:44 2014 -0700
Committer: Reynold Xin <[email protected]>
Committed: Sun May 25 18:37:44 2014 -0700
----------------------------------------------------------------------
.../scala/org/apache/spark/sql/SchemaRDD.scala | 18 ++++++++++++++++--
.../org/apache/spark/sql/DslQuerySuite.scala | 8 ++++++++
2 files changed, 24 insertions(+), 2 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/c3576ffc/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 9883ebc..e855f36 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -59,7 +59,7 @@ import java.util.{Map => JMap}
* // Importing the SQL context gives access to all the SQL functions and implicit conversions.
* import sqlContext._
*
- * val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_\$i")))
+ * val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i")))
* // Any RDD containing case classes can be registered as a table. The schema of the table is
* // automatically inferred using scala reflection.
* rdd.registerAsTable("records")
@@ -205,6 +205,20 @@ class SchemaRDD(
}
/**
+ * Performs an aggregation over all Rows in this RDD.
+ * This is equivalent to a groupBy with no grouping expressions.
+ *
+ * {{{
+ * schemaRDD.aggregate(Sum('sales) as 'totalSales)
+ * }}}
+ *
+ * @group Query
+ */
+ def aggregate(aggregateExprs: Expression*): SchemaRDD = {
+ groupBy()(aggregateExprs: _*)
+ }
+
+ /**
* Applies a qualifier to the attributes of this relation. Can be used to disambiguate attributes
* with the same name, for example, when performing self-joins.
*
@@ -281,7 +295,7 @@ class SchemaRDD(
* supports features such as filter pushdown.
*/
@Experimental
- override def count(): Long =
groupBy()(Count(Literal(1))).collect().head.getLong(0)
+ override def count(): Long =
aggregate(Count(Literal(1))).collect().head.getLong(0)
/**
* :: Experimental ::
http://git-wip-us.apache.org/repos/asf/spark/blob/c3576ffc/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index 94ba13b..692569a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -39,6 +39,14 @@ class DslQuerySuite extends QueryTest {
testData2.groupBy('a)('a, Sum('b)),
Seq((1,3),(2,3),(3,3))
)
+ checkAnswer(
+ testData2.groupBy('a)('a, Sum('b) as 'totB).aggregate(Sum('totB)),
+ 9
+ )
+ checkAnswer(
+ testData2.aggregate(Sum('b)),
+ 9
+ )
}
test("select *") {