Repository: spark
Updated Branches:
  refs/heads/branch-2.0 435d903d3 -> a7e8cfa64


[SPARK-15079] Support average/count/sum in Long/DoubleAccumulator

## What changes were proposed in this pull request?
This patch removes AverageAccumulator and adds the ability to compute count and 
average to LongAccumulator and DoubleAccumulator, in addition to sum. The patch 
also improves documentation for the two accumulators.

## How was this patch tested?
Added unit tests for this.

Author: Reynold Xin <[email protected]>

Closes #12858 from rxin/SPARK-15079.

(cherry picked from commit bb9ab56b960153d374d7e8838f62a18e7e45481e)
Signed-off-by: Reynold Xin <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a7e8cfa6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a7e8cfa6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a7e8cfa6

Branch: refs/heads/branch-2.0
Commit: a7e8cfa64de26be2e517e2eda237a9e8a58008c5
Parents: 435d903
Author: Reynold Xin <[email protected]>
Authored: Mon May 2 21:12:48 2016 -0700
Committer: Reynold Xin <[email protected]>
Committed: Mon May 2 21:13:00 2016 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/Accumulator.scala    |  17 ---
 .../scala/org/apache/spark/AccumulatorV2.scala  | 137 ++++++++++++-------
 .../scala/org/apache/spark/SparkContext.scala   |  22 ---
 .../org/apache/spark/AccumulatorSuite.scala     |  17 +--
 .../apache/spark/util/AccumulatorV2Suite.scala  |  89 ++++++++++++
 5 files changed, 181 insertions(+), 101 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a7e8cfa6/core/src/main/scala/org/apache/spark/Accumulator.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/Accumulator.scala 
b/core/src/main/scala/org/apache/spark/Accumulator.scala
index e52d36b..2324504 100644
--- a/core/src/main/scala/org/apache/spark/Accumulator.scala
+++ b/core/src/main/scala/org/apache/spark/Accumulator.scala
@@ -17,9 +17,6 @@
 
 package org.apache.spark
 
-import org.apache.spark.storage.{BlockId, BlockStatus}
-
-
 /**
  * A simpler value of [[Accumulable]] where the result type being accumulated 
is the same
  * as the types of elements being merged, i.e. variables that are only "added" 
to through an
@@ -117,18 +114,4 @@ object AccumulatorParam {
     def addInPlace(t1: String, t2: String): String = t2
     def zero(initialValue: String): String = ""
   }
-
-  // Note: this is expensive as it makes a copy of the list every time the 
caller adds an item.
-  // A better way to use this is to first accumulate the values yourself then 
them all at once.
-  @deprecated("use AccumulatorV2", "2.0.0")
-  private[spark] class ListAccumulatorParam[T] extends 
AccumulatorParam[Seq[T]] {
-    def addInPlace(t1: Seq[T], t2: Seq[T]): Seq[T] = t1 ++ t2
-    def zero(initialValue: Seq[T]): Seq[T] = Seq.empty[T]
-  }
-
-  // For the internal metric that records what blocks are updated in a 
particular task
-  @deprecated("use AccumulatorV2", "2.0.0")
-  private[spark] object UpdatedBlockStatusesAccumulatorParam
-    extends ListAccumulatorParam[(BlockId, BlockStatus)]
-
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/a7e8cfa6/core/src/main/scala/org/apache/spark/AccumulatorV2.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/AccumulatorV2.scala 
b/core/src/main/scala/org/apache/spark/AccumulatorV2.scala
index c65108a..a6c64fd 100644
--- a/core/src/main/scala/org/apache/spark/AccumulatorV2.scala
+++ b/core/src/main/scala/org/apache/spark/AccumulatorV2.scala
@@ -257,23 +257,66 @@ private[spark] object AccumulatorContext {
 }
 
 
+/**
+ * An [[AccumulatorV2 accumulator]] for computing sum, count, and averages for 
64-bit integers.
+ *
+ * @since 2.0.0
+ */
 class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] {
   private[this] var _sum = 0L
+  private[this] var _count = 0L
 
-  override def isZero: Boolean = _sum == 0
+  /**
+   * Returns whether this accumulator is zero, i.e. its count is 0.
+   * @since 2.0.0
+   */
+  override def isZero: Boolean = _count == 0L
 
   override def copyAndReset(): LongAccumulator = new LongAccumulator
 
-  override def add(v: jl.Long): Unit = _sum += v
+  /**
+   * Adds v to the accumulator, i.e. increment sum by v and count by 1.
+   * @since 2.0.0
+   */
+  override def add(v: jl.Long): Unit = {
+    _sum += v
+    _count += 1
+  }
+
+  /**
+   * Adds v to the accumulator, i.e. increment sum by v and count by 1.
+   * @since 2.0.0
+   */
+  def add(v: Long): Unit = {
+    _sum += v
+    _count += 1
+  }
 
-  def add(v: Long): Unit = _sum += v
+  /**
+   * Returns the number of elements added to the accumulator.
+   * @since 2.0.0
+   */
+  def count: Long = _count
 
+  /**
+   * Returns the sum of elements added to the accumulator.
+   * @since 2.0.0
+   */
   def sum: Long = _sum
 
+  /**
+   * Returns the average of elements added to the accumulator.
+   * @since 2.0.0
+   */
+  def avg: Double = _sum.toDouble / _count
+
   override def merge(other: AccumulatorV2[jl.Long, jl.Long]): Unit = other 
match {
-    case o: LongAccumulator => _sum += o.sum
-    case _ => throw new UnsupportedOperationException(
-      s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
+    case o: LongAccumulator =>
+      _sum += o.sum
+      _count += o.count
+    case _ =>
+      throw new UnsupportedOperationException(
+        s"Cannot merge ${this.getClass.getName} with 
${other.getClass.getName}")
   }
 
   private[spark] def setValue(newValue: Long): Unit = _sum = newValue
@@ -282,66 +325,68 @@ class LongAccumulator extends AccumulatorV2[jl.Long, 
jl.Long] {
 }
 
 
+/**
+ * An [[AccumulatorV2 accumulator]] for computing sum, count, and averages for 
double-precision
+ * floating-point numbers.
+ *
+ * @since 2.0.0
+ */
 class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double] {
   private[this] var _sum = 0.0
-
-  override def isZero: Boolean = _sum == 0.0
-
-  override def copyAndReset(): DoubleAccumulator = new DoubleAccumulator
-
-  override def add(v: jl.Double): Unit = _sum += v
-
-  def add(v: Double): Unit = _sum += v
-
-  def sum: Double = _sum
-
-  override def merge(other: AccumulatorV2[jl.Double, jl.Double]): Unit = other 
match {
-    case o: DoubleAccumulator => _sum += o.sum
-    case _ => throw new UnsupportedOperationException(
-      s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
-  }
-
-  private[spark] def setValue(newValue: Double): Unit = _sum = newValue
-
-  override def localValue: jl.Double = _sum
-}
-
-
-class AverageAccumulator extends AccumulatorV2[jl.Double, jl.Double] {
-  private[this] var _sum = 0.0
   private[this] var _count = 0L
 
-  override def isZero: Boolean = _sum == 0.0 && _count == 0
+  override def isZero: Boolean = _count == 0L
 
-  override def copyAndReset(): AverageAccumulator = new AverageAccumulator
+  override def copyAndReset(): DoubleAccumulator = new DoubleAccumulator
 
+  /**
+   * Adds v to the accumulator, i.e. increment sum by v and count by 1.
+   * @since 2.0.0
+   */
   override def add(v: jl.Double): Unit = {
     _sum += v
     _count += 1
   }
 
-  def add(d: Double): Unit = {
-    _sum += d
+  /**
+   * Adds v to the accumulator, i.e. increment sum by v and count by 1.
+   * @since 2.0.0
+   */
+  def add(v: Double): Unit = {
+    _sum += v
     _count += 1
   }
 
+  /**
+   * Returns the number of elements added to the accumulator.
+   * @since 2.0.0
+   */
+  def count: Long = _count
+
+  /**
+   * Returns the sum of elements added to the accumulator.
+   * @since 2.0.0
+   */
+  def sum: Double = _sum
+
+  /**
+   * Returns the average of elements added to the accumulator.
+   * @since 2.0.0
+   */
+  def avg: Double = _sum / _count
+
   override def merge(other: AccumulatorV2[jl.Double, jl.Double]): Unit = other 
match {
-    case o: AverageAccumulator =>
+    case o: DoubleAccumulator =>
       _sum += o.sum
       _count += o.count
-    case _ => throw new UnsupportedOperationException(
-      s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}")
-  }
-
-  override def localValue: jl.Double = if (_count == 0) {
-    Double.NaN
-  } else {
-    _sum / _count
+    case _ =>
+      throw new UnsupportedOperationException(
+        s"Cannot merge ${this.getClass.getName} with 
${other.getClass.getName}")
   }
 
-  def sum: Double = _sum
+  private[spark] def setValue(newValue: Double): Unit = _sum = newValue
 
-  def count: Long = _count
+  override def localValue: jl.Double = _sum
 }
 
 

http://git-wip-us.apache.org/repos/asf/spark/blob/a7e8cfa6/core/src/main/scala/org/apache/spark/SparkContext.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala 
b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 58618b4..e391599 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1341,28 +1341,6 @@ class SparkContext(config: SparkConf) extends Logging 
with ExecutorAllocationCli
   }
 
   /**
-   * Create and register an average accumulator, which accumulates double 
inputs by recording the
-   * total sum and total count, and produce the output by sum / total.  Note 
that Double.NaN will be
-   * returned if no input is added.
-   */
-  def averageAccumulator: AverageAccumulator = {
-    val acc = new AverageAccumulator
-    register(acc)
-    acc
-  }
-
-  /**
-   * Create and register an average accumulator, which accumulates double 
inputs by recording the
-   * total sum and total count, and produce the output by sum / total.  Note 
that Double.NaN will be
-   * returned if no input is added.
-   */
-  def averageAccumulator(name: String): AverageAccumulator = {
-    val acc = new AverageAccumulator
-    register(acc, name)
-    acc
-  }
-
-  /**
    * Create and register a list accumulator, which starts with empty list and 
accumulates inputs
    * by adding them into the inner list.
    */

http://git-wip-us.apache.org/repos/asf/spark/blob/a7e8cfa6/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala 
b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
index 09eb9c1..0020096 100644
--- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
@@ -28,7 +28,7 @@ import scala.util.control.NonFatal
 import org.scalatest.Matchers
 import org.scalatest.exceptions.TestFailedException
 
-import org.apache.spark.AccumulatorParam.{ListAccumulatorParam, 
StringAccumulatorParam}
+import org.apache.spark.AccumulatorParam.StringAccumulatorParam
 import org.apache.spark.scheduler._
 import org.apache.spark.serializer.JavaSerializer
 
@@ -234,21 +234,6 @@ class AccumulatorSuite extends SparkFunSuite with Matchers 
with LocalSparkContex
     acc.merge("kindness")
     assert(acc.value === "kindness")
   }
-
-  test("list accumulator param") {
-    val acc = new Accumulator(Seq.empty[Int], new ListAccumulatorParam[Int], 
Some("numbers"))
-    assert(acc.value === Seq.empty[Int])
-    acc.add(Seq(1, 2))
-    assert(acc.value === Seq(1, 2))
-    acc += Seq(3, 4)
-    assert(acc.value === Seq(1, 2, 3, 4))
-    acc ++= Seq(5, 6)
-    assert(acc.value === Seq(1, 2, 3, 4, 5, 6))
-    acc.merge(Seq(7, 8))
-    assert(acc.value === Seq(1, 2, 3, 4, 5, 6, 7, 8))
-    acc.setValue(Seq(9, 10))
-    assert(acc.value === Seq(9, 10))
-  }
 }
 
 private[spark] object AccumulatorSuite {

http://git-wip-us.apache.org/repos/asf/spark/blob/a7e8cfa6/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala 
b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala
new file mode 100644
index 0000000..41cdd02
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import org.apache.spark.{DoubleAccumulator, LongAccumulator, SparkFunSuite}
+
+class AccumulatorV2Suite extends SparkFunSuite {
+
+  test("LongAccumulator add/avg/sum/count/isZero") {
+    val acc = new LongAccumulator
+    assert(acc.isZero)
+    assert(acc.count == 0)
+    assert(acc.sum == 0)
+    assert(acc.avg.isNaN)
+
+    acc.add(0)
+    assert(!acc.isZero)
+    assert(acc.count == 1)
+    assert(acc.sum == 0)
+    assert(acc.avg == 0.0)
+
+    acc.add(1)
+    assert(acc.count == 2)
+    assert(acc.sum == 1)
+    assert(acc.avg == 0.5)
+
+    // Also test add using non-specialized add function
+    acc.add(new java.lang.Long(2))
+    assert(acc.count == 3)
+    assert(acc.sum == 3)
+    assert(acc.avg == 1.0)
+
+    // Test merging
+    val acc2 = new LongAccumulator
+    acc2.add(2)
+    acc.merge(acc2)
+    assert(acc.count == 4)
+    assert(acc.sum == 5)
+    assert(acc.avg == 1.25)
+  }
+
+  test("DoubleAccumulator add/avg/sum/count/isZero") {
+    val acc = new DoubleAccumulator
+    assert(acc.isZero)
+    assert(acc.count == 0)
+    assert(acc.sum == 0.0)
+    assert(acc.avg.isNaN)
+
+    acc.add(0.0)
+    assert(!acc.isZero)
+    assert(acc.count == 1)
+    assert(acc.sum == 0.0)
+    assert(acc.avg == 0.0)
+
+    acc.add(1.0)
+    assert(acc.count == 2)
+    assert(acc.sum == 1.0)
+    assert(acc.avg == 0.5)
+
+    // Also test add using non-specialized add function
+    acc.add(new java.lang.Double(2.0))
+    assert(acc.count == 3)
+    assert(acc.sum == 3.0)
+    assert(acc.avg == 1.0)
+
+    // Test merging
+    val acc2 = new DoubleAccumulator
+    acc2.add(2.0)
+    acc.merge(acc2)
+    assert(acc.count == 4)
+    assert(acc.sum == 5.0)
+    assert(acc.avg == 1.25)
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to