This is an automated email from the ASF dual-hosted git repository.
aherbert pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/commons-statistics.git
The following commit(s) were added to refs/heads/master by this push:
new 6eb4678 Add benchmark for moment-based statistics
6eb4678 is described below
commit 6eb46788cbbc8a6c52f0c7e3070b0d5224152494
Author: aherbert <[email protected]>
AuthorDate: Mon Dec 11 16:57:52 2023 +0000
Add benchmark for moment-based statistics
The benchmark tests algorithms computing the mean.
---
commons-statistics-examples/examples-jmh/pom.xml | 5 +
.../jmh/descriptive/MomentPerformance.java | 404 +++++++++++++++++++++
.../jmh/descriptive/MomentPerformanceTest.java | 49 +++
commons-statistics-examples/pom.xml | 5 +
4 files changed, 463 insertions(+)
diff --git a/commons-statistics-examples/examples-jmh/pom.xml
b/commons-statistics-examples/examples-jmh/pom.xml
index 008edaf..f2e9386 100644
--- a/commons-statistics-examples/examples-jmh/pom.xml
+++ b/commons-statistics-examples/examples-jmh/pom.xml
@@ -31,6 +31,11 @@
Code in this module is not part of the public API.</description>
<dependencies>
+ <dependency>
+ <groupId>org.apache.commons</groupId>
+ <artifactId>commons-statistics-descriptive</artifactId>
+ </dependency>
+
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-statistics-distribution</artifactId>
diff --git
a/commons-statistics-examples/examples-jmh/src/main/java/org/apache/commons/statistics/examples/jmh/descriptive/MomentPerformance.java
b/commons-statistics-examples/examples-jmh/src/main/java/org/apache/commons/statistics/examples/jmh/descriptive/MomentPerformance.java
new file mode 100644
index 0000000..121b421
--- /dev/null
+++
b/commons-statistics-examples/examples-jmh/src/main/java/org/apache/commons/statistics/examples/jmh/descriptive/MomentPerformance.java
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.statistics.examples.jmh.descriptive;
+
+import java.util.Arrays;
+import java.util.concurrent.TimeUnit;
+import java.util.function.DoubleConsumer;
+import java.util.function.DoubleSupplier;
+import java.util.function.Function;
+import java.util.function.Supplier;
+import org.apache.commons.rng.simple.RandomSource;
+import org.apache.commons.statistics.descriptive.Mean;
+import org.openjdk.jmh.annotations.Benchmark;
+import org.openjdk.jmh.annotations.BenchmarkMode;
+import org.openjdk.jmh.annotations.Fork;
+import org.openjdk.jmh.annotations.Level;
+import org.openjdk.jmh.annotations.Measurement;
+import org.openjdk.jmh.annotations.Mode;
+import org.openjdk.jmh.annotations.OutputTimeUnit;
+import org.openjdk.jmh.annotations.Param;
+import org.openjdk.jmh.annotations.Scope;
+import org.openjdk.jmh.annotations.Setup;
+import org.openjdk.jmh.annotations.State;
+import org.openjdk.jmh.annotations.Warmup;
+
+/**
+ * Executes a benchmark of the moment-based statistics.
+ */
+@BenchmarkMode(Mode.AverageTime)
+@OutputTimeUnit(TimeUnit.NANOSECONDS)
+@Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
+@Measurement(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
+@State(Scope.Benchmark)
+@Fork(value = 1, jvmArgs = {"-server", "-Xms512M", "-Xmx512M"})
+public class MomentPerformance {
+ /** Commons Statistics Mean implementation. */
+ private static final String MEAN = "Mean";
+ /** Summation mean implementation. */
+ private static final String SUM_MEAN = "SumMean";
+ /** Extended precision summation mean implementation. */
+ private static final String EXTENDED_SUM_MEAN = "ExtendedSumMean";
+ /** Rolling mean implementation. */
+ private static final String ROLLING_MEAN = "RollingMean";
+ /** Safe rolling mean implementation. */
+ private static final String SAFE_ROLLING_MEAN = "SafeRollingMean";
+ /** Inline rolling mean implementation for array-based creation. */
+ private static final String INLINE_ROLLING_MEAN = "InlineRollingMean";
+
+ /**
+ * Source of {@code double} array data.
+ */
+ @State(Scope.Benchmark)
+ public static class DataSource {
+ /** Data length. */
+ @Param({"1", "10", "1000"})
+ private int length;
+
+ /** Data. */
+ private double[] data;
+
+ /**
+ * @return the data
+ */
+ public double[] getData() {
+ return data;
+ }
+
+ /**
+ * Create the data.
+ */
+ @Setup(Level.Iteration)
+ public void setup() {
+ // Data will be randomized per iteration
+ data =
RandomSource.XO_RO_SHI_RO_128_PP.create().doubles(length).toArray();
+ }
+ }
+
+ /**
+ * Source of a {@link DoubleConsumer} action.
+ */
+ @State(Scope.Benchmark)
+ public static class ActionSource {
+ /** Name of the source. */
+ @Param({MEAN, ROLLING_MEAN, SAFE_ROLLING_MEAN, SUM_MEAN,
EXTENDED_SUM_MEAN})
+ private String name;
+
+ /** The action. */
+ private Supplier<DoubleConsumer> action;
+
+ /**
+ * @return the action
+ */
+ public DoubleConsumer getAction() {
+ return action.get();
+ }
+
+ /**
+ * Create the data.
+ */
+ @Setup(Level.Iteration)
+ public void setup() {
+ if (MEAN.equals(name)) {
+ action = Mean::create;
+ } else if (ROLLING_MEAN.equals(name)) {
+ action = RollingFirstMoment::new;
+ } else if (SAFE_ROLLING_MEAN.equals(name)) {
+ action = SafeRollingFirstMoment::new;
+ } else if (SUM_MEAN.equals(name)) {
+ action = SumFirstMoment::new;
+ } else if (EXTENDED_SUM_MEAN.equals(name)) {
+ action = ExtendedSumFirstMoment::new;
+ } else {
+ throw new IllegalStateException("Unknown action: " + name);
+ }
+ }
+ }
+
+ /**
+ * Source of a {@link Function} for a {@code double[]}.
+ */
+ @State(Scope.Benchmark)
+ public static class FunctionSource {
+ /** Name of the source. */
+ @Param({MEAN, ROLLING_MEAN, SAFE_ROLLING_MEAN,
+ // Same speed as the ROLLING_MEAN, i.e. the DoubleConsumer is not
an overhead
+ //INLINE_ROLLING_MEAN
+ })
+ private String name;
+
+ /** The action. */
+ private Function<double[], Object> function;
+
+ /**
+ * @return the function
+ */
+ public Function<double[], Object> getFunction() {
+ return function;
+ }
+
+ /**
+ * Create the data.
+ */
+ @Setup(Level.Iteration)
+ public void setup() {
+ if (MEAN.equals(name)) {
+ function = Mean::of;
+ } else if (ROLLING_MEAN.equals(name)) {
+ function = MomentPerformance::arrayRollingFirstMoment;
+ } else if (SAFE_ROLLING_MEAN.equals(name)) {
+ function = MomentPerformance::arraySafeRollingFirstMoment;
+ } else if (INLINE_ROLLING_MEAN.equals(name)) {
+ function = MomentPerformance::arrayInlineRollingFirstMoment;
+ } else {
+ throw new IllegalStateException("Unknown function: " + name);
+ }
+ }
+ }
+
+ /**
+ * A rolling first raw moment of {@code double} data.
+ */
+ static class RollingFirstMoment implements DoubleConsumer, DoubleSupplier {
+ /** Count of values that have been added. */
+ private long n;
+
+ /** First moment of values that have been added. */
+ private double m1;
+
+ @Override
+ public void accept(double value) {
+ m1 += (value - m1) / ++n;
+ }
+
+ @Override
+ public double getAsDouble() {
+ // NaN for all non-finite results
+ return Double.isFinite(m1) && n != 0 ? m1 : Double.NaN;
+ }
+ }
+
+ /**
+ * A rolling first raw moment of {@code double} data safe to overflow of
any finite
+ * values (e.g. [MAX_VALUE, -MAX_VALUE]).
+ */
+ static class SafeRollingFirstMoment implements DoubleConsumer,
DoubleSupplier {
+ /** Count of values that have been added. */
+ private long n;
+
+ /** First moment of values that have been added. */
+ private double m1;
+
+ @Override
+ public void accept(double value) {
+ m1 += ((value * 0.5 - m1 * 0.5) / ++n) * 2;
+ }
+
+ @Override
+ public double getAsDouble() {
+ // NaN for all non-finite results
+ return Double.isFinite(m1) && n != 0 ? m1 : Double.NaN;
+ }
+ }
+
+ /**
+ * A mean using a sum.
+ */
+ static class SumFirstMoment implements DoubleConsumer, DoubleSupplier {
+ /** Count of values that have been added. */
+ private long n;
+
+ /** Sum of values that have been added. */
+ private double sum;
+
+ @Override
+ public void accept(double value) {
+ n++;
+ sum += value;
+ }
+
+ @Override
+ public double getAsDouble() {
+ return sum / n;
+ }
+ }
+
+ /**
+ * A mean using an extended precision sum.
+ *
+ * <p>This type of summation is used in DoubleStream to compute the sum
and derive the
+ * mean. This method acts as a proxy to compare the speed of the rolling
algorithm to
+ * collect a stream verses a high-precision sum using
+ * {@link java.util.stream.DoubleStream#sum()}.
+ */
+ static class ExtendedSumFirstMoment implements DoubleConsumer,
DoubleSupplier {
+ /** Count of values that have been added. */
+ private long n;
+
+ /** Sum of values that have been added. */
+ private double sum;
+ /** A running compensation for lost low-order bits. */
+ private double c;
+
+ @Override
+ public void accept(double value) {
+ n++;
+ // Kahan summation
+ // https://en.wikipedia.org/wiki/Kahan_summation_algorithm
+ final double y = value - c;
+ final double t = sum + y;
+ c = (t - sum) - y;
+ sum = t;
+ }
+
+ @Override
+ public double getAsDouble() {
+ return sum / n;
+ }
+ }
+
+ /**
+ * Apply the action to each value.
+ *
+ * @param <T> the action type
+ * @param action Action.
+ * @param values Values.
+ * @return the action
+ */
+ static <T extends DoubleConsumer> T forEach(T action, double[] values) {
+ for (final double x : values) {
+ action.accept(x);
+ }
+ return action;
+ }
+
+ /**
+ * Correct the mean using a second pass over the data.
+ *
+ * @param data Data.
+ * @param xbar Current mean.
+ * @return the mean
+ */
+ private static double correctMean(double[] data, double xbar) {
+ double correction = 0;
+ for (final double x : data) {
+ correction += x - xbar;
+ }
+ // Note: Correction may be infinite
+ if (Double.isFinite(correction)) {
+ return xbar + correction / data.length;
+ }
+ return xbar;
+ }
+
+ /**
+ * Create the two-pass mean using a rolling first moment.
+ *
+ * @param data Data.
+ * @return the statistic
+ */
+ static double arrayRollingFirstMoment(double[] data) {
+ final RollingFirstMoment m1 = new RollingFirstMoment();
+ for (final double x : data) {
+ m1.accept(x);
+ }
+ final double xbar = m1.getAsDouble();
+ if (!Double.isFinite(xbar)) {
+ // Note: Also occurs when the input is empty
+ return xbar;
+ }
+ return correctMean(data, xbar);
+ }
+
+ /**
+ * Create the two-pass mean using a rolling first moment
+ * safe to overflow.
+ *
+ * @param data Data.
+ * @return the statistic
+ */
+ static double arraySafeRollingFirstMoment(double[] data) {
+ final SafeRollingFirstMoment m1 = new SafeRollingFirstMoment();
+ for (final double x : data) {
+ m1.accept(x);
+ }
+ final double xbar = m1.getAsDouble();
+ if (!Double.isFinite(xbar)) {
+ // Note: Also occurs when the input is empty
+ return xbar;
+ }
+ return correctMean(data, xbar);
+ }
+
+ /**
+ * Create the two-pass mean using a rolling first moment inline.
+ *
+ * <p>Note: This method is effectively the same as {@link
#arrayRollingFirstMoment(double[])}
+ * and timing tests show there is no overhead to using an object to
aggregate the first moment,
+ * i.e. this is not faster.
+ *
+ * @param data Data.
+ * @return the statistic
+ */
+ static double arrayInlineRollingFirstMoment(double[] data) {
+ double m1 = 0;
+ int n = 0;
+ for (final double x : data) {
+ m1 += (x - m1) / ++n;
+ }
+ if (!Double.isFinite(m1) || n == 0) {
+ return Double.NaN;
+ }
+ return correctMean(data, m1);
+ }
+
+ /**
+ * Create the mean from a stream of {@code double} values.
+ *
+ * @param source Source of the data.
+ * @return the mean
+ */
+ @Benchmark
+ public Object streamMean(DataSource source) {
+ return Arrays.stream(source.getData()).average();
+ }
+
+ /**
+ * Create the statistic using a consumer of {@code double} values.
+ *
+ * @param action Source of the data action.
+ * @param source Source of the data.
+ * @return the statistic
+ */
+ @Benchmark
+ public Object forEachStatistic(ActionSource action, DataSource source) {
+ return forEach(action.getAction(), source.getData());
+ }
+
+ /**
+ * Create the statistic using a {@code double[]} function.
+ *
+ * @param function Source of the function.
+ * @param source Source of the data.
+ * @return the statistic
+ */
+ @Benchmark
+ public Object arrayStatistic(FunctionSource function, DataSource source) {
+ return function.getFunction().apply(source.getData());
+ }
+}
diff --git
a/commons-statistics-examples/examples-jmh/src/test/java/org/apache/commons/statistics/examples/jmh/descriptive/MomentPerformanceTest.java
b/commons-statistics-examples/examples-jmh/src/test/java/org/apache/commons/statistics/examples/jmh/descriptive/MomentPerformanceTest.java
new file mode 100644
index 0000000..733f132
--- /dev/null
+++
b/commons-statistics-examples/examples-jmh/src/test/java/org/apache/commons/statistics/examples/jmh/descriptive/MomentPerformanceTest.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.commons.statistics.examples.jmh.descriptive;
+
+import java.util.concurrent.ThreadLocalRandom;
+import org.apache.commons.statistics.descriptive.Mean;
+import
org.apache.commons.statistics.examples.jmh.descriptive.MomentPerformance.ExtendedSumFirstMoment;
+import
org.apache.commons.statistics.examples.jmh.descriptive.MomentPerformance.RollingFirstMoment;
+import
org.apache.commons.statistics.examples.jmh.descriptive.MomentPerformance.SafeRollingFirstMoment;
+import
org.apache.commons.statistics.examples.jmh.descriptive.MomentPerformance.SumFirstMoment;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.ValueSource;
+
+/**
+ * Executes tests for {@link MomentPerformance}.
+ */
+class MomentPerformanceTest {
+ @ParameterizedTest
+ @ValueSource(ints = {0, 1, 5, 10})
+ void testFirstMoment(int n) {
+ final double[] values =
ThreadLocalRandom.current().doubles(n).toArray();
+ // Expected should be 0.5 as n -> inf
+ final double expected = Mean.of(values).getAsDouble();
+ final double tolerance = n <= 1 ? 0 : expected * 1e-14;
+ Assertions.assertEquals(expected, MomentPerformance.forEach(new
RollingFirstMoment(), values).getAsDouble(), tolerance);
+ Assertions.assertEquals(expected, MomentPerformance.forEach(new
SafeRollingFirstMoment(), values).getAsDouble(), tolerance);
+ Assertions.assertEquals(expected, MomentPerformance.forEach(new
SumFirstMoment(), values).getAsDouble(), tolerance);
+ Assertions.assertEquals(expected, MomentPerformance.forEach(new
ExtendedSumFirstMoment(), values).getAsDouble(), tolerance);
+ Assertions.assertEquals(expected,
MomentPerformance.arrayRollingFirstMoment(values), tolerance);
+ Assertions.assertEquals(expected,
MomentPerformance.arraySafeRollingFirstMoment(values), tolerance);
+ Assertions.assertEquals(expected,
MomentPerformance.arrayInlineRollingFirstMoment(values), tolerance);
+ }
+}
diff --git a/commons-statistics-examples/pom.xml
b/commons-statistics-examples/pom.xml
index a4708df..fecd678 100644
--- a/commons-statistics-examples/pom.xml
+++ b/commons-statistics-examples/pom.xml
@@ -54,6 +54,11 @@
<dependencyManagement>
<dependencies>
+ <dependency>
+ <groupId>org.apache.commons</groupId>
+ <artifactId>commons-statistics-descriptive</artifactId>
+ <version>1.1-SNAPSHOT</version>
+ </dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-statistics-distribution</artifactId>