This is an automated email from the ASF dual-hosted git repository.
damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
new bde545f91e4 Add support for STDDEV_POP and STDDEV_SAMP in VarianceFn (#38871)
bde545f91e4 is described below
commit bde545f91e40e8f7a65e7d092918f1323769cc24
Author: Danny McCormick <[email protected]>
AuthorDate: Mon Jun 15 15:46:09 2026 -0400
Add support for STDDEV_POP and STDDEV_SAMP in VarianceFn (#38871)
* Add support for STDDEV_POP and STDDEV_SAMP in VarianceFn
* Add DSL integration tests for STDDEV_POP and STDDEV_SAMP
* Address review feedback: make fields final and add null check
* Address review feedback: handle numerical instability and overflow in stddev
* Address review feedback: return infinity on standard deviation overflow instead of throwing exception
---
.../impl/transform/BeamBuiltinAggregations.java | 2 +
.../sql/impl/transform/agg/VarianceFn.java | 38 ++++++++++++++++---
.../sql/BeamSqlDslAggregationVarianceTest.java | 43 +++++++++++++++++++++-
.../sql/impl/transform/agg/VarianceFnTest.java | 25 ++++++++++++-
4 files changed, 99 insertions(+), 9 deletions(-)
diff --git
a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAggregations.java
b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAggregations.java
index 3fc299bd5a3..2800edfbb99 100644
---
a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAggregations.java
+++
b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/BeamBuiltinAggregations.java
@@ -83,6 +83,8 @@ public class BeamBuiltinAggregations {
typeName -> new
DropNullFn(BeamBuiltinAggregations.createBitAnd(typeName)))
.put("VAR_POP", t -> VarianceFn.newPopulation(t.getTypeName()))
.put("VAR_SAMP", t -> VarianceFn.newSample(t.getTypeName()))
+ .put("STDDEV_POP", t ->
VarianceFn.newPopulationStddev(t.getTypeName()))
+ .put("STDDEV_SAMP", t ->
VarianceFn.newSampleStddev(t.getTypeName()))
.put("COVAR_POP", t ->
CovarianceFn.newPopulation(t.getTypeName()))
.put("COVAR_SAMP", t -> CovarianceFn.newSample(t.getTypeName()))
.put("COUNTIF", typeName -> CountIf.combineFn())
diff --git
a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/VarianceFn.java
b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/VarianceFn.java
index dd2cd3b2095..906bac7add5 100644
---
a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/VarianceFn.java
+++
b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/VarianceFn.java
@@ -75,8 +75,12 @@ public class VarianceFn<T extends Number> extends
Combine.CombineFn<T, VarianceA
private static final boolean SAMPLE = true;
private static final boolean POP = false;
- private boolean isSample; // flag to determine return value should be
Variance Pop or Sample
- private SerializableFunction<BigDecimal, T> decimalConverter;
+ private final boolean isSample; // flag to determine return value should be
Variance Pop or Sample
+ // When true, extractOutput returns the square root of the variance (i.e.
standard deviation).
+ // Beam's enumerable bridge cannot translate a SQRT call layered on top of a
window VAR_SAMP, so
+ // STDDEV_SAMP / STDDEV_POP are computed end-to-end inside this combiner
instead.
+ private final boolean isStddev;
+ private final SerializableFunction<BigDecimal, T> decimalConverter;
public static VarianceFn newPopulation(Schema.TypeName typeName) {
return newPopulation(BigDecimalConverter.forSqlType(typeName));
@@ -85,7 +89,7 @@ public class VarianceFn<T extends Number> extends
Combine.CombineFn<T, VarianceA
public static <V extends Number> VarianceFn newPopulation(
SerializableFunction<BigDecimal, V> decimalConverter) {
- return new VarianceFn<>(POP, decimalConverter);
+ return new VarianceFn<>(POP, false, decimalConverter);
}
public static VarianceFn newSample(Schema.TypeName typeName) {
@@ -95,11 +99,21 @@ public class VarianceFn<T extends Number> extends
Combine.CombineFn<T, VarianceA
public static <V extends Number> VarianceFn newSample(
SerializableFunction<BigDecimal, V> decimalConverter) {
- return new VarianceFn<>(SAMPLE, decimalConverter);
+ return new VarianceFn<>(SAMPLE, false, decimalConverter);
}
- private VarianceFn(boolean isSample, SerializableFunction<BigDecimal, T>
decimalConverter) {
+ public static VarianceFn newSampleStddev(Schema.TypeName typeName) {
+ return new VarianceFn<>(SAMPLE, true,
BigDecimalConverter.forSqlType(typeName));
+ }
+
+ public static VarianceFn newPopulationStddev(Schema.TypeName typeName) {
+ return new VarianceFn<>(POP, true,
BigDecimalConverter.forSqlType(typeName));
+ }
+
+ private VarianceFn(
+ boolean isSample, boolean isStddev, SerializableFunction<BigDecimal, T>
decimalConverter) {
this.isSample = isSample;
+ this.isStddev = isStddev;
this.decimalConverter = decimalConverter;
}
@@ -133,7 +147,19 @@ public class VarianceFn<T extends Number> extends
Combine.CombineFn<T, VarianceA
@Override
public T extractOutput(VarianceAccumulator accumulator) {
- return decimalConverter.apply(getVariance(accumulator));
+ BigDecimal result = getVariance(accumulator);
+ if (result != null && isStddev) {
+ double doubleVal = result.doubleValue();
+ if (doubleVal < 0.0) {
+ doubleVal = 0.0; // Clamp negative variance due to numerical
instability
+ }
+ double sqrtVal = Math.sqrt(doubleVal);
+ if (Double.isInfinite(sqrtVal)) {
+ return decimalConverter.apply(result.sqrt(MATH_CTX));
+ }
+ result = BigDecimal.valueOf(sqrtVal);
+ }
+ return decimalConverter.apply(result);
}
private BigDecimal getVariance(VarianceAccumulator variance) {
diff --git
a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationVarianceTest.java
b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationVarianceTest.java
index 808b27aaac4..e2c548acf71 100644
---
a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationVarianceTest.java
+++
b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/BeamSqlDslAggregationVarianceTest.java
@@ -30,7 +30,10 @@ import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
-/** Integration tests for {@code VAR_POP} and {@code VAR_SAMP}. */
+/**
+ * Integration tests for {@code VAR_POP}, {@code VAR_SAMP}, {@code STDDEV_POP}
and {@code
+ * STDDEV_SAMP}.
+ */
public class BeamSqlDslAggregationVarianceTest {
private static final double PRECISION = 1e-7;
@@ -94,4 +97,42 @@ public class BeamSqlDslAggregationVarianceTest {
pipeline.run().waitUntilFinish();
}
+
+ @Test
+ public void testPopulationStddevDouble() {
+ String sql = "SELECT STDDEV_POP(f_double) FROM PCOLLECTION GROUP BY
f_int2";
+
+ PAssert.that(boundedInput.apply(SqlTransform.query(sql)))
+ .satisfies(matchesScalar(5.138887357, PRECISION));
+
+ pipeline.run().waitUntilFinish();
+ }
+
+ @Test
+ public void testPopulationStddevInt() {
+ String sql = "SELECT STDDEV_POP(f_int) FROM PCOLLECTION GROUP BY f_int2";
+
+
PAssert.that(boundedInput.apply(SqlTransform.query(sql))).satisfies(matchesScalar(5));
+
+ pipeline.run().waitUntilFinish();
+ }
+
+ @Test
+ public void testSampleStddevDouble() {
+ String sql = "SELECT STDDEV_SAMP(f_double) FROM PCOLLECTION GROUP BY
f_int2";
+
+ PAssert.that(boundedInput.apply(SqlTransform.query(sql)))
+ .satisfies(matchesScalar(5.550632739, PRECISION));
+
+ pipeline.run().waitUntilFinish();
+ }
+
+ @Test
+ public void testSampleStddevInt() {
+ String sql = "SELECT STDDEV_SAMP(f_int) FROM PCOLLECTION GROUP BY f_int2";
+
+
PAssert.that(boundedInput.apply(SqlTransform.query(sql))).satisfies(matchesScalar(5));
+
+ pipeline.run().waitUntilFinish();
+ }
}
diff --git
a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/VarianceFnTest.java
b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/VarianceFnTest.java
index f7a8ad1fa06..0671a3caaa6 100644
---
a/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/VarianceFnTest.java
+++
b/sdks/java/extensions/sql/src/test/java/org/apache/beam/sdk/extensions/sql/impl/transform/agg/VarianceFnTest.java
@@ -26,6 +26,7 @@ import java.math.BigDecimal;
import java.util.Arrays;
import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.coders.VarIntCoder;
+import org.apache.beam.sdk.schemas.Schema;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@@ -51,18 +52,38 @@ public class VarianceFnTest {
VarianceFn.newSample(BigDecimal::intValue),
newVarianceAccumulator(FIFTEEN, FOUR, ZERO),
5
+ },
+ {
+ VarianceFn.newPopulationStddev(Schema.TypeName.INT32),
+ newVarianceAccumulator(new BigDecimal(36), new BigDecimal(4),
ZERO),
+ 3
+ },
+ {
+ VarianceFn.newSampleStddev(Schema.TypeName.INT32),
+ newVarianceAccumulator(new BigDecimal(36), new BigDecimal(5),
ZERO),
+ 3
+ },
+ {
+ VarianceFn.newPopulationStddev(Schema.TypeName.DOUBLE),
+ newVarianceAccumulator(new BigDecimal("1e700"), BigDecimal.ONE,
ZERO),
+ Double.POSITIVE_INFINITY
+ },
+ {
+ VarianceFn.newPopulationStddev(Schema.TypeName.FLOAT),
+ newVarianceAccumulator(new BigDecimal("1e700"), BigDecimal.ONE,
ZERO),
+ Float.POSITIVE_INFINITY
}
});
}
private VarianceFn varianceFn;
private VarianceAccumulator testAccumulatorInput;
- private int expectedExtractedResult;
+ private Object expectedExtractedResult;
public VarianceFnTest(
VarianceFn varianceFn,
VarianceAccumulator testAccumulatorInput,
- int expectedExtractedResult) {
+ Object expectedExtractedResult) {
this.varianceFn = varianceFn;
this.testAccumulatorInput = testAccumulatorInput;