This is an automated email from the ASF dual-hosted git repository. wenchen pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new de0d7fb [SPARK-16280][SPARK-37082][SQL] Implements histogram_numeric aggregation function which supports partial aggregation de0d7fb is described below

commit de0d7fbb4f010bec8e457d0dc00b5618e7a43750 Author: Angerszhuuuu <angers....@gmail.com> AuthorDate: Wed Oct 27 19:47:17 2021 +0800

[SPARK-16280][SPARK-37082][SQL] Implements histogram_numeric aggregation function which supports partial aggregation

### What changes were proposed in this pull request? This PR implements the aggregation function `histogram_numeric`, which returns an approximate histogram of a numerical column using a user-specified number of bins, for example the histogram of column `col` split into 3 bins.

Syntax: #### Returns an approximate histogram of a numerical column using a user-specified number of bins. histogram_numeric(col, nBins) ##### Returns an approximate histogram of column `col` in 3 bins. SELECT histogram_numeric(col, 3) FROM table ##### Returns an approximate histogram of column `col` in 5 bins. SELECT histogram_numeric(col, 5) FROM table

### Why are the changes needed? Until now, `histogram_numeric` was only available through the Hive UDAF fallback, which does not support partial aggregation; this native implementation does.

### Does this PR introduce _any_ user-facing change? No change from the user's side.

### How was this patch tested? Added unit tests.

Closes #34380 from AngersZhuuuu/SPARK-37082. Authored-by: Angerszhuuuu <angers....@gmail.com> Signed-off-by: Wenchen Fan <wenc...@databricks.com> --- .../apache/spark/sql/util/NumericHistogram.java | 286 +++++++++++++++++++++ .../sql/catalyst/analysis/FunctionRegistry.scala | 1 + .../sql/catalyst/catalog/SessionCatalog.scala | 2 +- .../expressions/aggregate/HistogramNumeric.scala | 207 +++++++++++++++ .../aggregate/HistogramNumericSuite.scala | 166 ++++++++++++ .../sql-functions/sql-expression-schema.md | 3 +- .../test/resources/sql-tests/inputs/group-by.sql | 12 + .../resources/sql-tests/results/group-by.sql.out | 20 +- .../apache/spark/sql/hive/HiveSessionCatalog.scala | 4 +- 9 files changed, 695 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/util/NumericHistogram.java b/sql/catalyst/src/main/java/org/apache/spark/sql/util/NumericHistogram.java new file mode 100644 index 0000000..987c18e --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/util/NumericHistogram.java @@ -0,0 +1,286 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.util; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Random; + + +/** + * A generic, re-usable histogram class that supports partial aggregations. + * The algorithm is a heuristic adapted from the following paper: + * Yael Ben-Haim and Elad Tom-Tov, "A streaming parallel decision tree algorithm", + * J. 
Machine Learning Research 11 (2010), pp. 849--872. Although there are no approximation + * guarantees, it appears to work well with adequate data and a large (e.g., 20-80) number + * of histogram bins. + * + * Adapted from Hive's NumericHistogram; see + * https://github.com/apache/hive/blob/master/ql/src/ + * java/org/apache/hadoop/hive/ql/udf/generic/NumericHistogram.java + * + * Differences: + * 1. Declares [[Coord]] and its fields as public for + * easy access in the HistogramNumeric class. + * 2. Adds method [[getNumBins()]] for serializing [[NumericHistogram]] + * in [[NumericHistogramSerializer]]. + * 3. Adds method [[addBin()]] for deserializing [[NumericHistogram]] + * in [[NumericHistogramSerializer]]. + * 4. In Hive's code, the method [[merge()]] receives a serialized histogram; + * in Spark it receives a deserialized one, so the bin-merging code + * is changed accordingly. + */ +public class NumericHistogram { + /** + * The Coord class defines a histogram bin, which is just an (x,y) pair. + */ + public static class Coord implements Comparable { + public double x; + public double y; + + public int compareTo(Object other) { + return Double.compare(x, ((Coord) other).x); + } + } + + // Class variables + private int nbins; + private int nusedbins; + private ArrayList<Coord> bins; + private Random prng; + + /** + * Creates a new histogram object. Note that the allocate() or merge() + * method must be called before the histogram can be used. + */ + public NumericHistogram() { + nbins = 0; + nusedbins = 0; + bins = null; + + // init the RNG for breaking ties in histogram merging. A fixed seed is specified here + // to aid testing, but can be eliminated to use a time-based seed (which would + // make the algorithm non-deterministic). + prng = new Random(31183); + } + + /** + * Resets a histogram object to its initial state. allocate() or merge() must be + * called again before use. + */ + public void reset() { + bins = null; + nbins = nusedbins = 0; + } + + /** + * Returns the number of bins. + */ + public int getNumBins() { + return nbins; + } + + /** + * Returns the number of bins currently being used by the histogram. + */ + public int getUsedBins() { + return nusedbins; + } + + /** + * Set the number of bins currently being used by the histogram. + */ + public void setUsedBins(int nusedBins) { + this.nusedbins = nusedBins; + } + + /** + * Returns true if this histogram object has been initialized by calling merge() + * or allocate(). + */ + public boolean isReady() { + return nbins != 0; + } + + /** + * Returns a particular histogram bin. + */ + public Coord getBin(int b) { + return bins.get(b); + } + + /** + * Inserts a histogram bin at the given index. + */ + public void addBin(double x, double y, int b) { + Coord coord = new Coord(); + coord.x = x; + coord.y = y; + bins.add(b, coord); + } + + /** + * Sets the number of histogram bins to use for approximating data. + * + * @param num_bins Number of non-uniform-width histogram bins to use + */ + public void allocate(int num_bins) { + nbins = num_bins; + bins = new ArrayList<Coord>(); + nusedbins = 0; + } + + /** + * Takes a histogram and merges it with the current histogram object. 
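+ * For example, with nbins = 2, merging a histogram whose bins are {(1,5), (4,3)} + * into one whose only bin is {(2,2)} first builds the sorted, overstuffed list + * {(1,5), (2,2), (4,3)}; trim() then combines the closest pair, (1,5) and (2,2), + * into a single bin at their height-weighted mean: ((1*5 + 2*2) / 7, 5 + 2) = (9/7, 7).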
+ */ + public void merge(NumericHistogram other) { + if (other == null) { + return; + } + + if (nbins == 0 || nusedbins == 0) { + // Our aggregation buffer has nothing in it, so just copy over 'other' + // by cloning its (x,y) pairs into fresh Coord objects + nbins = other.nbins; + nusedbins = other.nusedbins; + bins = new ArrayList<Coord>(nusedbins); + for (int i = 0; i < other.nusedbins; i += 1) { + Coord bin = new Coord(); + bin.x = other.getBin(i).x; + bin.y = other.getBin(i).y; + bins.add(bin); + } + } else { + // The aggregation buffer already contains a partial histogram. Therefore, we need + // to merge histograms using Algorithm #2 from the Ben-Haim and Tom-Tov paper. + + ArrayList<Coord> tmp_bins = new ArrayList<Coord>(nusedbins + other.nusedbins); + // Copy all the histogram bins from us and 'other' into an overstuffed histogram + for (int i = 0; i < nusedbins; i++) { + Coord bin = new Coord(); + bin.x = bins.get(i).x; + bin.y = bins.get(i).y; + tmp_bins.add(bin); + } + for (int j = 0; j < other.nusedbins; j += 1) { + Coord bin = new Coord(); + bin.x = other.getBin(j).x; + bin.y = other.getBin(j).y; + tmp_bins.add(bin); + } + Collections.sort(tmp_bins); + + // Now trim the overstuffed histogram down to the correct number of bins + bins = tmp_bins; + nusedbins += other.nusedbins; + trim(); + } + } + + + /** + * Adds a new data point to the histogram approximation. Make sure you have + * called either allocate() or merge() first. This method implements Algorithm #1 + * from Ben-Haim and Tom-Tov, "A Streaming Parallel Decision Tree Algorithm", JMLR 2010. + * + * @param v The data point to add to the histogram approximation. + */ + public void add(double v) { + // Binary search to find the closest bucket that v should go into. + // 'bin' should be interpreted as the bin to shift right in order to accommodate + // v. As a result, bin is in the range [0,N], where N means that the value v is + // greater than all the N bins currently in the histogram. It is also possible that + // a bucket centered at 'v' already exists, so this must be checked in the next step. + int bin = 0; + for (int l = 0, r = nusedbins; l < r; ) { + bin = (l + r) / 2; + if (bins.get(bin).x > v) { + r = bin; + } else { + if (bins.get(bin).x < v) { + l = ++bin; + } else { + break; // exact match found; stop the search + } + } + } + + // If we found an exact bin match for value v, then just increment that bin's count. + // Otherwise, we need to insert a new bin and trim the resulting histogram back to size. + // A possible optimization here might be to set some threshold under which 'v' is just + // assumed to be equal to the closest bin -- if fabs(v-bins[bin].x) < THRESHOLD, then + // just increment 'bin'. This is not done now because we don't want to make any + // assumptions about the range of numeric data being analyzed. + if (bin < nusedbins && bins.get(bin).x == v) { + bins.get(bin).y++; + } else { + Coord newBin = new Coord(); + newBin.x = v; + newBin.y = 1; + bins.add(bin, newBin); + + // Trim the bins down to the correct number of bins. + if (++nusedbins > nbins) { + trim(); + } + } + + } + + /** + * Trims a histogram down to 'nbins' bins by iteratively merging the closest bins. + * If two pairs of bins are equally close to each other, decide uniformly at random which + * pair to merge, based on a PRNG. + */ + private void trim() { + while (nusedbins > nbins) { + // Find the closest pair of bins in terms of x coordinates. Break ties randomly. 
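+ // Ties are broken reservoir-style: the k-th pair seen at the current minimum + // replaces the previous candidate with probability 1/k, so every tied pair is + // equally likely to be merged.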
+ double smallestdiff = bins.get(1).x - bins.get(0).x; + int smallestdiffloc = 0, smallestdiffcount = 1; + for (int i = 1; i < nusedbins - 1; i++) { + double diff = bins.get(i + 1).x - bins.get(i).x; + if (diff < smallestdiff) { + smallestdiff = diff; + smallestdiffloc = i; + smallestdiffcount = 1; + } else { + if (diff == smallestdiff && prng.nextDouble() <= (1.0 / ++smallestdiffcount)) { + smallestdiffloc = i; + } + } + } + + // Merge the two closest bins into their average x location, weighted by their heights. + // The height of the new bin is the sum of the heights of the old bins. + // double d = bins[smallestdiffloc].y + bins[smallestdiffloc+1].y; + // bins[smallestdiffloc].x *= bins[smallestdiffloc].y / d; + // bins[smallestdiffloc].x += bins[smallestdiffloc+1].x / d * + // bins[smallestdiffloc+1].y; + // bins[smallestdiffloc].y = d; + + double d = bins.get(smallestdiffloc).y + bins.get(smallestdiffloc + 1).y; + Coord smallestdiffbin = bins.get(smallestdiffloc); + smallestdiffbin.x *= smallestdiffbin.y / d; + smallestdiffbin.x += bins.get(smallestdiffloc + 1).x / d * bins.get(smallestdiffloc + 1).y; + smallestdiffbin.y = d; + // Shift the remaining bins left one position + bins.remove(smallestdiffloc + 1); + nusedbins--; + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index f53c829..4d316ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -433,6 +433,7 @@ object FunctionRegistry { expression[Skewness]("skewness"), expression[ApproximatePercentile]("percentile_approx"), expression[ApproximatePercentile]("approx_percentile", true), + expression[HistogramNumeric]("histogram_numeric"), expression[StddevSamp]("std", true), expression[StddevSamp]("stddev", true), expression[StddevPop]("stddev_pop"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index c3cc78e..141de75 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1508,7 +1508,7 @@ class SessionCatalog( */ def isTemporaryFunction(name: FunctionIdentifier): Boolean = { // copied from HiveSessionCatalog - val hiveFunctions = Seq("histogram_numeric") + val hiveFunctions = Seq() // A temporary function is a function that has been registered in functionRegistry // without a database name, and is neither a built-in function nor a Hive function diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumeric.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumeric.scala new file mode 100644 index 0000000..09408e6 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumeric.scala @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import java.nio.ByteBuffer + +import com.google.common.primitives.{Doubles, Ints} + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, ImplicitCastInputTypes} +import org.apache.spark.sql.catalyst.trees.BinaryLike +import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType, DateType, DayTimeIntervalType, DoubleType, IntegerType, NumericType, StructField, StructType, TimestampNTZType, TimestampType, TypeCollection, YearMonthIntervalType} +import org.apache.spark.sql.util.NumericHistogram + +/** + * Computes an approximate histogram of a numerical column using a user-specified number of bins. + * + * The output is an array of (x,y) pairs as struct objects that represent the histogram's + * bin centers and heights. + */ +@ExpressionDescription( + usage = """ + _FUNC_(expr, nb) - Computes a histogram on numeric 'expr' using nb bins. + The return value is an array of (x,y) pairs representing the centers of the + histogram's bins. As the value of 'nb' is increased, the histogram approximation + gets finer-grained, but may yield artifacts around outliers. In practice, 20-40 + histogram bins appear to work well, with more bins being required for skewed or + smaller datasets. Note that this function creates a histogram with non-uniform + bin widths. It offers no guarantees in terms of the mean-squared-error of the + histogram, but in practice is comparable to the histograms produced by the R/S-Plus + statistical computing packages. + """, + examples = """ + Examples: + > SELECT _FUNC_(col, 5) FROM VALUES (0), (1), (2), (10) AS tab(col); + [{"x":0.0,"y":1.0},{"x":1.0,"y":1.0},{"x":2.0,"y":1.0},{"x":10.0,"y":1.0}] + """, + group = "agg_funcs", + since = "3.3.0") +case class HistogramNumeric( + child: Expression, + nBins: Expression, + override val mutableAggBufferOffset: Int, + override val inputAggBufferOffset: Int) + extends TypedImperativeAggregate[NumericHistogram] with ImplicitCastInputTypes + with BinaryLike[Expression] { + + def this(child: Expression, nBins: Expression) = { + this(child, nBins, 0, 0) + } + + private lazy val nb = nBins.eval() match { + case null => null + case n: Int => n + } + + override def inputTypes: Seq[AbstractDataType] = { + // Support NumericType, DateType, TimestampType, TimestampNTZType, YearMonthIntervalType + // and DayTimeIntervalType, since their internal types are all numeric and can easily + // be cast to double for processing. 
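+ // For example, DateType is physically an Int (days since the epoch) and TimestampType + // a Long (microseconds), so both widen to Double with no special handling. +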
Seq(TypeCollection(NumericType, DateType, TimestampType, TimestampNTZType, + YearMonthIntervalType, DayTimeIntervalType), IntegerType) + } + + override def checkInputDataTypes(): TypeCheckResult = { + val defaultCheck = super.checkInputDataTypes() + if (defaultCheck.isFailure) { + defaultCheck + } else if (!nBins.foldable) { + TypeCheckFailure(s"${this.prettyName} requires the nBins argument to be a constant literal.") + } else if (nb == null) { + TypeCheckFailure(s"${this.prettyName} requires a non-null nBins value.") + } else if (nb.asInstanceOf[Int] < 2) { + TypeCheckFailure(s"${this.prettyName} requires nBins to be at least 2, but $nb was provided.") + } else { + TypeCheckSuccess + } + } + + override def createAggregationBuffer(): NumericHistogram = { + val buffer = new NumericHistogram() + buffer.allocate(nb.asInstanceOf[Int]) + buffer + } + + override def update(buffer: NumericHistogram, inputRow: InternalRow): NumericHistogram = { + val value = child.eval(inputRow) + // Ignore null input values, for example: histogram_numeric(null) + if (value != null) { + // Convert the input to a double for binning + val doubleValue = value.asInstanceOf[Number].doubleValue + buffer.add(doubleValue) + } + buffer + } + + override def merge( + buffer: NumericHistogram, + other: NumericHistogram): NumericHistogram = { + buffer.merge(other) + buffer + } + + override def eval(buffer: NumericHistogram): Any = { + if (buffer.getUsedBins < 1) { + null + } else { + val result = (0 until buffer.getUsedBins).map { index => + val coord = buffer.getBin(index) + InternalRow.apply(coord.x, coord.y) + } + new GenericArrayData(result) + } + } + + override def serialize(obj: NumericHistogram): Array[Byte] = { + NumericHistogramSerializer.serialize(obj) + } + + override def deserialize(bytes: Array[Byte]): NumericHistogram = { + NumericHistogramSerializer.deserialize(bytes) + } + + override def left: Expression = child + + override def right: Expression = nBins + + override protected def withNewChildrenInternal( + newLeft: Expression, + newRight: Expression): HistogramNumeric = { + copy(child = newLeft, nBins = newRight) + } + + override def withNewMutableAggBufferOffset(newOffset: Int): HistogramNumeric = + copy(mutableAggBufferOffset = newOffset) + + override def withNewInputAggBufferOffset(newOffset: Int): HistogramNumeric = + copy(inputAggBufferOffset = newOffset) + + override def nullable: Boolean = true + + override def dataType: DataType = + ArrayType(new StructType(Array( + StructField("x", DoubleType, true), + StructField("y", DoubleType, true))), true) + + override def prettyName: String = "histogram_numeric" +} + +object NumericHistogramSerializer { + private final def length(histogram: NumericHistogram): Int = { + // histogram.nBins, histogram.nUsedBins + Ints.BYTES + Ints.BYTES + + // histogram.bins, Array[Coord(x: Double, y: Double)] + histogram.getUsedBins * (Doubles.BYTES + Doubles.BYTES) + } + + def serialize(histogram: NumericHistogram): Array[Byte] = { + val buffer = ByteBuffer.wrap(new Array(length(histogram))) + buffer.putInt(histogram.getNumBins) + buffer.putInt(histogram.getUsedBins) + + var i = 0 + while (i < histogram.getUsedBins) { + val coord = histogram.getBin(i) + buffer.putDouble(coord.x) + buffer.putDouble(coord.y) + i += 1 + } + buffer.array() + } + + def deserialize(bytes: Array[Byte]): NumericHistogram = { + val buffer = ByteBuffer.wrap(bytes) + val nBins = buffer.getInt() + val nUsedBins = buffer.getInt() + val histogram = new NumericHistogram() + histogram.allocate(nBins) + 
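// Restore the used-bin count, then re-add each serialized (x, y) pair in order. +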
histogram.setUsedBins(nUsedBins) + var i: Int = 0 + while (i < nUsedBins) { + val x = buffer.getDouble() + val y = buffer.getDouble() + histogram.addBin(x, y, i) + i += 1 + } + histogram + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumericSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumericSuite.scala new file mode 100644 index 0000000..60b53c6 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HistogramNumericSuite.scala @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.dsl.expressions.{DslString, DslSymbol} +import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, BoundReference, Cast, GenericInternalRow, Literal} +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.types.{DoubleType, IntegerType} +import org.apache.spark.sql.util.NumericHistogram + +class HistogramNumericSuite extends SparkFunSuite { + + private val random = new java.util.Random() + + private val data = (0 until 10000).map { _ => + random.nextInt(10000) + } + + test("serialize and de-serialize") { + + // Check empty serialize and de-serialize + val emptyBuffer = new NumericHistogram() + emptyBuffer.allocate(5) + assert(compareEquals(emptyBuffer, + NumericHistogramSerializer.deserialize(NumericHistogramSerializer.serialize(emptyBuffer)))) + + val buffer = new NumericHistogram() + buffer.allocate(data.size / 3) + data.foreach { value => + buffer.add(value) + } + assert(compareEquals(buffer, + NumericHistogramSerializer.deserialize(NumericHistogramSerializer.serialize(buffer)))) + + val agg = new HistogramNumeric(BoundReference(0, DoubleType, true), Literal(5)) + assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer)) + } + + test("class NumericHistogram, basic operations") { + val valueCount = 5 + Seq(3, 5).foreach { nBins: Int => + val buffer = new NumericHistogram() + buffer.allocate(nBins) + (1 to valueCount).grouped(nBins).foreach { group => + val partialBuffer = new NumericHistogram() + partialBuffer.allocate(nBins) + group.foreach(x => partialBuffer.add(x)) + buffer.merge(partialBuffer) + } + val sum = (0 until buffer.getUsedBins).map { i => + val coord = buffer.getBin(i) + coord.x * coord.y + }.sum + assert(sum <= 
(1 to valueCount).sum) + } + } + + test("class HistogramNumeric, sql string") { + assertEqual("histogram_numeric(a, 3)", + new HistogramNumeric("a".attr, Literal(3)).sql) + + // sql with isDistinct = true + assertEqual("histogram_numeric(DISTINCT a, 3)", + new HistogramNumeric("a".attr, Literal(3)).sql(isDistinct = true)) + } + + test("class HistogramNumeric, fails analysis if nBins is not a constant") { + val attribute = AttributeReference("a", IntegerType)() + val wrongNB = new HistogramNumeric(attribute, nBins = AttributeReference("b", IntegerType)()) + + assertEqual( + wrongNB.checkInputDataTypes(), + TypeCheckFailure("histogram_numeric requires the nBins argument to be a constant literal.") + ) + } + + test("class HistogramNumeric, fails analysis if nBins is invalid") { + val attribute = AttributeReference("a", IntegerType)() + val wrongNB = new HistogramNumeric(attribute, nBins = Literal(1)) + + assertEqual( + wrongNB.checkInputDataTypes(), + TypeCheckFailure("histogram_numeric requires nBins to be at least 2, but 1 was provided.") + ) + } + + test("class HistogramNumeric, automatically add type casting for parameters") { + val testRelation = LocalRelation('a.int) + + // nBins may be any integral type; it is implicitly cast to int during analysis + val nBinsExpressions = Seq( + Literal(2.toByte), + Literal(100.toShort), + Literal(100), + Literal(1000L)) + + nBinsExpressions.foreach { nBins => + val agg = new HistogramNumeric(UnresolvedAttribute("a"), nBins) + val analyzed = testRelation.select(agg).analyze.expressions.head + analyzed match { + case Alias(agg: HistogramNumeric, _) => + assert(agg.resolved) + assert(agg.child.dataType == IntegerType) + assert(agg.nBins.dataType == IntegerType) + case _ => fail() + } + } + } + + test("HistogramNumeric: nulls in nBins expression") { + assert(new HistogramNumeric( + AttributeReference("a", DoubleType)(), + Literal(null, IntegerType)).checkInputDataTypes() === + TypeCheckFailure("histogram_numeric requires a non-null nBins value.")) + } + + test("class HistogramNumeric, null handling") { + val childExpression = Cast(BoundReference(0, IntegerType, nullable = true), DoubleType) + val agg = new HistogramNumeric(childExpression, Literal(5)) + val buffer = new GenericInternalRow(new Array[Any](1)) + agg.initialize(buffer) + // Empty aggregation buffer + assert(agg.eval(buffer) == null) + // Input row with a null value + agg.update(buffer, InternalRow(null)) + assert(agg.eval(buffer) == null) + + // Add a row with a non-null value + agg.update(buffer, InternalRow(0)) + assert(agg.eval(buffer) != null) + } + + private def compareEquals(left: NumericHistogram, right: NumericHistogram): Boolean = { + left.getNumBins == right.getNumBins && left.getUsedBins == right.getUsedBins && + (0 until left.getUsedBins).forall { i => + val leftCoord = left.getBin(i) + val rightCoord = right.getBin(i) + leftCoord.x == rightCoord.x && leftCoord.y == rightCoord.y + } + } + + private def assertEqual[T](left: T, right: T): Unit = { + assert(left == right) + } +} diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index 958b961..9192ac4 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -1,6 +1,6 @@ <!-- Automatically generated by ExpressionsSchemaSuite --> ## Summary - - Number of queries: 366 + - Number of queries: 
367 - Number of expressions that missing example: 12 - Expressions missing examples: bigint,binary,boolean,date,decimal,double,float,int,smallint,string,timestamp,tinyint ## Schema of Built-in Functions @@ -345,6 +345,7 @@ | org.apache.spark.sql.catalyst.expressions.aggregate.CovSample | covar_samp | SELECT covar_samp(c1, c2) FROM VALUES (1,1), (2,2), (3,3) AS tab(c1, c2) | struct<covar_samp(c1, c2):double> | | org.apache.spark.sql.catalyst.expressions.aggregate.First | first | SELECT first(col) FROM VALUES (10), (5), (20) AS tab(col) | struct<first(col):int> | | org.apache.spark.sql.catalyst.expressions.aggregate.First | first_value | SELECT first_value(col) FROM VALUES (10), (5), (20) AS tab(col) | struct<first_value(col):int> | +| org.apache.spark.sql.catalyst.expressions.aggregate.HistogramNumeric | histogram_numeric | SELECT histogram_numeric(col, 5) FROM VALUES (0), (1), (2), (10) AS tab(col) | struct<histogram_numeric(col, 5):array<struct<x:double,y:double>>> | | org.apache.spark.sql.catalyst.expressions.aggregate.HyperLogLogPlusPlus | approx_count_distinct | SELECT approx_count_distinct(col1) FROM VALUES (1), (1), (2), (2), (3) tab(col1) | struct<approx_count_distinct(col1):bigint> | | org.apache.spark.sql.catalyst.expressions.aggregate.Kurtosis | kurtosis | SELECT kurtosis(col) FROM VALUES (-10), (-20), (100), (1000) AS tab(col) | struct<kurtosis(col):double> | | org.apache.spark.sql.catalyst.expressions.aggregate.Last | last | SELECT last(col) FROM VALUES (10), (5), (20) AS tab(col) | struct<last(col):int> | diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 039373b..4e6d2d2 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -192,3 +192,15 @@ SELECT if(not(a IS NULL), rand(0), 1), count(*) AS c FROM testData GROUP BY a IS NULL; + +SELECT + histogram_numeric(col, 2) as histogram_2, + histogram_numeric(col, 3) as histogram_3, + histogram_numeric(col, 5) as histogram_5, + histogram_numeric(col, 10) as histogram_10 +FROM VALUES + (1), (2), (3), (4), (5), (6), (7), (8), (9), (10), + (11), (12), (13), (14), (15), (16), (17), (18), (19), (20), + (21), (22), (23), (24), (25), (26), (27), (28), (29), (30), + (31), (32), (33), (34), (35), (3), (37), (38), (39), (40), + (41), (42), (43), (44), (45), (46), (47), (48), (49), (50) AS tab(col); diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index f598f49..5cd5a37 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 65 +-- Number of queries: 66 -- !query @@ -673,3 +673,21 @@ struct<(IF((NOT (a IS NULL)), rand(0), 1)):double,c:bigint> -- !query output 0.7604953758285915 7 1.0 2 + + +-- !query +SELECT + histogram_numeric(col, 2) as histogram_2, + histogram_numeric(col, 3) as histogram_3, + histogram_numeric(col, 5) as histogram_5, + histogram_numeric(col, 10) as histogram_10 +FROM VALUES + (1), (2), (3), (4), (5), (6), (7), (8), (9), (10), + (11), (12), (13), (14), (15), (16), (17), (18), (19), (20), + (21), (22), (23), (24), (25), (26), (27), (28), (29), (30), + (31), (32), (33), (34), (35), (3), (37), (38), (39), (40), + (41), (42), (43), (44), (45), (46), (47), (48), (49), (50) AS tab(col) +-- !query schema 
+struct<histogram_2:array<struct<x:double,y:double>>,histogram_3:array<struct<x:double,y:double>>,histogram_5:array<struct<x:double,y:double>>,histogram_10:array<struct<x:double,y:double>>> -- !query output +[{"x":12.615384615384613,"y":26.0},{"x":38.083333333333336,"y":24.0}] [{"x":9.649999999999999,"y":20.0},{"x":25.0,"y":11.0},{"x":40.736842105263165,"y":19.0}] [{"x":5.272727272727273,"y":11.0},{"x":14.5,"y":8.0},{"x":22.0,"y":7.0},{"x":30.499999999999996,"y":10.0},{"x":43.5,"y":14.0}] [{"x":3.0,"y":6.0},{"x":8.5,"y":6.0},{"x":13.5,"y":4.0},{"x":17.0,"y":3.0},{"x":20.5,"y":4.0},{"x":25.5,"y":6.0},{"x":31.999999999999996,"y":7.0},{"x":39.0,"y":5.0},{"x":43.5,"y":4.0},{"x":48.0,"y":5.0}] diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 488890a..b11774b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -128,7 +128,5 @@ private[sql] class HiveSessionCatalog( // in_file, index, matchpath, ngrams, noop, noopstreaming, noopwithmap, // noopwithmapstreaming, parse_url_tuple, reflect2, windowingtablefunction. // Note: don't forget to update SessionCatalog.isTemporaryFunction - private val hiveFunctions = Seq( - "histogram_numeric" - ) + private val hiveFunctions = Seq() }
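For readers who want to try the new built-in, here is a minimal, hypothetical Scala sketch (the local session setup and the `tab`/`col` names are illustrative only; the expected output mirrors the ExpressionDescription example above):

    import org.apache.spark.sql.SparkSession

    object HistogramNumericDemo {
      def main(args: Array[String]): Unit = {
        // Hypothetical local session; any session built against this change works.
        val spark = SparkSession.builder()
          .appName("histogram-numeric-demo")
          .master("local[*]")
          .getOrCreate()
        import spark.implicits._

        // Four distinct values with up to five bins: each value keeps its own bin.
        Seq(0, 1, 2, 10).toDF("col").createOrReplaceTempView("tab")
        spark.sql("SELECT histogram_numeric(col, 5) FROM tab").show(false)
        // Expected: [{0.0, 1.0}, {1.0, 1.0}, {2.0, 1.0}, {10.0, 1.0}]

        spark.stop()
      }
    }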