mkaravel commented on code in PR #51298: URL: https://github.com/apache/spark/pull/51298#discussion_r2330463138
########## python/pyspark/sql/functions/builtin.py: ########## @@ -15799,7 +15799,8 @@ def regexp_count(str: "ColumnOrName", regexp: "ColumnOrName") -> Column: @_try_remote_functions def regexp_extract(str: "ColumnOrName", pattern: str, idx: int) -> Column: - r"""Extract a specific group matched by the Java regex `regexp`, from the specified string column. + r"""Extract a specific group matched by the Java regex `regexp`, from the specified + string column. Review Comment: I would not make this change as part of this PR. It is unrelated to the actual purpose of the PR (which is to introduce Theta sketches functions). ########## python/pyspark/sql/functions/builtin.py: ########## @@ -25704,6 +25705,385 @@ def hll_union( return _invoke_function("hll_union", _to_java_column(col1), _to_java_column(col2)) +@_try_remote_functions +def theta_sketch_agg( + col: "ColumnOrName", + lgNomEntries: Optional[Union[int, Column]] = None, +) -> Column: + """ + Aggregate function: returns the compact binary representation of the Datasketches + ThetaSketch configured with lgNomEntries arg. + + .. versionadded:: 4.1.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or column name + lgNomEntries : :class:`~pyspark.sql.Column` or int, optional + The log-base-2 of nominal entries, where nominal entries is the size of the sketch + (must be between 4 and 26, defaults to 12) + + Returns + ------- + :class:`~pyspark.sql.Column` + The binary representation of the ThetaSketch. + + See Also + -------- + :meth:`pyspark.sql.functions.theta_union` + :meth:`pyspark.sql.functions.theta_intersection` + :meth:`pyspark.sql.functions.theta_difference` + :meth:`pyspark.sql.functions.theta_union_agg` + :meth:`pyspark.sql.functions.theta_intersection_agg` + :meth:`pyspark.sql.functions.theta_sketch_estimate` + + Examples + -------- + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([1,2,2,3], "INT") + >>> df.agg(sf.theta_sketch_estimate(sf.theta_sketch_agg("value"))).show() + +--------------------------------------------------+ + |theta_sketch_estimate(theta_sketch_agg(value, 12))| + +--------------------------------------------------+ + | 3| + +--------------------------------------------------+ + + >>> df.agg(sf.theta_sketch_estimate(sf.theta_sketch_agg("value", 15))).show() + +--------------------------------------------------+ + |theta_sketch_estimate(theta_sketch_agg(value, 15))| + +--------------------------------------------------+ + | 3| + +--------------------------------------------------+ + """ + fn = "theta_sketch_agg" + if lgNomEntries is None: + return _invoke_function_over_columns(fn, col) + else: + return _invoke_function_over_columns(fn, col, lit(lgNomEntries)) + + +@_try_remote_functions +def theta_union_agg( + col: "ColumnOrName", + lgNomEntries: Optional[Union[int, Column]] = None, +) -> Column: + """ + Aggregate function: returns the compact binary representation of the Datasketches + ThetaSketch, generated by merging previously created Datasketches ThetaSketch instances + via a Datasketches Union instance. + + .. versionadded:: 4.1.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or column name + lgNomEntries : :class:`~pyspark.sql.Column` or int, optional + The log-base-2 of nominal entries for the union operation + (must be between 4 and 26, defaults to 12) + + Returns + ------- + :class:`~pyspark.sql.Column` + The binary representation of the merged ThetaSketch. 
+ + See Also + -------- + :meth:`pyspark.sql.functions.theta_union` + :meth:`pyspark.sql.functions.theta_sketch_agg` + :meth:`pyspark.sql.functions.theta_sketch_estimate` + + Examples + -------- + >>> from pyspark.sql import functions as sf + >>> df1 = spark.createDataFrame([1,2,2,3], "INT") + >>> df1 = df1.agg(sf.theta_sketch_agg("value").alias("sketch")) + >>> df2 = spark.createDataFrame([4,5,5,6], "INT") + >>> df2 = df2.agg(sf.theta_sketch_agg("value").alias("sketch")) + >>> df3 = df1.union(df2) + >>> df3.agg(sf.theta_sketch_estimate(sf.theta_union_agg("sketch"))).show() + +--------------------------------------------------+ + |theta_sketch_estimate(theta_union_agg(sketch, 12))| + +--------------------------------------------------+ + | 6| + +--------------------------------------------------+ + """ + fn = "theta_union_agg" + if lgNomEntries is None: + return _invoke_function_over_columns(fn, col) + else: + return _invoke_function_over_columns(fn, col, lit(lgNomEntries)) + + +@_try_remote_functions +def theta_intersection_agg( + col: "ColumnOrName", + lgNomEntries: Optional[Union[int, Column]] = None, +) -> Column: + """ + Aggregate function: returns the compact binary representation of the Datasketches + ThetaSketch, generated by intersecting previously created Datasketches ThetaSketch + instances via a Datasketches Intersection instance. Review Comment: Similar comment here. ########## python/pyspark/sql/functions/builtin.py: ########## @@ -25704,6 +25705,385 @@ def hll_union( return _invoke_function("hll_union", _to_java_column(col1), _to_java_column(col2)) +@_try_remote_functions +def theta_sketch_agg( + col: "ColumnOrName", + lgNomEntries: Optional[Union[int, Column]] = None, +) -> Column: + """ + Aggregate function: returns the compact binary representation of the Datasketches + ThetaSketch configured with lgNomEntries arg. + + .. versionadded:: 4.1.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or column name + lgNomEntries : :class:`~pyspark.sql.Column` or int, optional + The log-base-2 of nominal entries, where nominal entries is the size of the sketch + (must be between 4 and 26, defaults to 12) + + Returns + ------- + :class:`~pyspark.sql.Column` + The binary representation of the ThetaSketch. 
+ + See Also + -------- + :meth:`pyspark.sql.functions.theta_union` + :meth:`pyspark.sql.functions.theta_intersection` + :meth:`pyspark.sql.functions.theta_difference` + :meth:`pyspark.sql.functions.theta_union_agg` + :meth:`pyspark.sql.functions.theta_intersection_agg` + :meth:`pyspark.sql.functions.theta_sketch_estimate` + + Examples + -------- + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([1,2,2,3], "INT") + >>> df.agg(sf.theta_sketch_estimate(sf.theta_sketch_agg("value"))).show() + +--------------------------------------------------+ + |theta_sketch_estimate(theta_sketch_agg(value, 12))| + +--------------------------------------------------+ + | 3| + +--------------------------------------------------+ + + >>> df.agg(sf.theta_sketch_estimate(sf.theta_sketch_agg("value", 15))).show() + +--------------------------------------------------+ + |theta_sketch_estimate(theta_sketch_agg(value, 15))| + +--------------------------------------------------+ + | 3| + +--------------------------------------------------+ + """ + fn = "theta_sketch_agg" + if lgNomEntries is None: + return _invoke_function_over_columns(fn, col) + else: + return _invoke_function_over_columns(fn, col, lit(lgNomEntries)) + + +@_try_remote_functions +def theta_union_agg( + col: "ColumnOrName", + lgNomEntries: Optional[Union[int, Column]] = None, +) -> Column: + """ + Aggregate function: returns the compact binary representation of the Datasketches + ThetaSketch, generated by merging previously created Datasketches ThetaSketch instances + via a Datasketches Union instance. + + .. versionadded:: 4.1.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or column name + lgNomEntries : :class:`~pyspark.sql.Column` or int, optional + The log-base-2 of nominal entries for the union operation + (must be between 4 and 26, defaults to 12) + + Returns + ------- + :class:`~pyspark.sql.Column` + The binary representation of the merged ThetaSketch. + + See Also + -------- + :meth:`pyspark.sql.functions.theta_union` + :meth:`pyspark.sql.functions.theta_sketch_agg` + :meth:`pyspark.sql.functions.theta_sketch_estimate` + + Examples + -------- + >>> from pyspark.sql import functions as sf + >>> df1 = spark.createDataFrame([1,2,2,3], "INT") + >>> df1 = df1.agg(sf.theta_sketch_agg("value").alias("sketch")) + >>> df2 = spark.createDataFrame([4,5,5,6], "INT") + >>> df2 = df2.agg(sf.theta_sketch_agg("value").alias("sketch")) + >>> df3 = df1.union(df2) + >>> df3.agg(sf.theta_sketch_estimate(sf.theta_union_agg("sketch"))).show() + +--------------------------------------------------+ + |theta_sketch_estimate(theta_union_agg(sketch, 12))| + +--------------------------------------------------+ + | 6| + +--------------------------------------------------+ + """ + fn = "theta_union_agg" + if lgNomEntries is None: + return _invoke_function_over_columns(fn, col) + else: + return _invoke_function_over_columns(fn, col, lit(lgNomEntries)) + + +@_try_remote_functions +def theta_intersection_agg( + col: "ColumnOrName", + lgNomEntries: Optional[Union[int, Column]] = None, +) -> Column: + """ + Aggregate function: returns the compact binary representation of the Datasketches + ThetaSketch, generated by intersecting previously created Datasketches ThetaSketch + instances via a Datasketches Intersection instance. + + .. 
versionadded:: 4.1.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or column name + lgNomEntries : :class:`~pyspark.sql.Column` or int, optional + The log-base-2 of nominal entries for the intersection operation + (must be between 4 and 26, defaults to 12) + + Returns + ------- + :class:`~pyspark.sql.Column` + The binary representation of the intersected ThetaSketch. + + See Also + -------- + :meth:`pyspark.sql.functions.theta_intersection` + :meth:`pyspark.sql.functions.theta_sketch_agg` + :meth:`pyspark.sql.functions.theta_sketch_estimate` + + Examples + -------- + >>> from pyspark.sql import functions as sf + >>> df1 = spark.createDataFrame([1,2,2,3], "INT") + >>> df1 = df1.agg(sf.theta_sketch_agg("value").alias("sketch")) + >>> df2 = spark.createDataFrame([2,3,3,4], "INT") + >>> df2 = df2.agg(sf.theta_sketch_agg("value").alias("sketch")) + >>> df3 = df1.union(df2) + >>> df3.agg(sf.theta_sketch_estimate(sf.theta_intersection_agg("sketch"))).show() + +---------------------------------------------------------+ + |theta_sketch_estimate(theta_intersection_agg(sketch, 12))| + +---------------------------------------------------------+ + | 2| + +---------------------------------------------------------+ + """ + fn = "theta_intersection_agg" + if lgNomEntries is None: + return _invoke_function_over_columns(fn, col) + else: + return _invoke_function_over_columns(fn, col, lit(lgNomEntries)) + + +@_try_remote_functions +def theta_sketch_estimate(col: "ColumnOrName") -> Column: + """ + Returns the estimated number of unique values given the binary representation + of a Datasketches ThetaSketch. + + .. versionadded:: 4.1.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or column name + + Returns + ------- + :class:`~pyspark.sql.Column` + The estimated number of unique values for the ThetaSketch. + + See Also + -------- + :meth:`pyspark.sql.functions.theta_union` + :meth:`pyspark.sql.functions.theta_intersection` + :meth:`pyspark.sql.functions.theta_difference` + :meth:`pyspark.sql.functions.theta_union_agg` + :meth:`pyspark.sql.functions.theta_intersection_agg` + :meth:`pyspark.sql.functions.theta_sketch_agg` + + Examples + -------- + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([1,2,2,3], "INT") + >>> df.agg(sf.theta_sketch_estimate(sf.theta_sketch_agg("value"))).show() + +--------------------------------------------------+ + |theta_sketch_estimate(theta_sketch_agg(value, 12))| + +--------------------------------------------------+ + | 3| + +--------------------------------------------------+ + """ + from pyspark.sql.classic.column import _to_java_column + + fn = "theta_sketch_estimate" + return _invoke_function(fn, _to_java_column(col)) + + +@_try_remote_functions +def theta_union( + col1: "ColumnOrName", col2: "ColumnOrName", lgNomEntries: Optional[int] = None +) -> Column: + """ + Merges two binary representations of Datasketches ThetaSketch objects, using a + Datasketches Union object. + + .. versionadded:: 4.1.0 + + Parameters + ---------- + col1 : :class:`~pyspark.sql.Column` or column name + col2 : :class:`~pyspark.sql.Column` or column name + lgNomEntries : int, optional + The log-base-2 of nominal entries for the union operation + (must be between 4 and 26, defaults to 12) + + Returns + ------- + :class:`~pyspark.sql.Column` + The binary representation of the merged ThetaSketch. 
+ + See Also + -------- + :meth:`pyspark.sql.functions.theta_union_agg` + :meth:`pyspark.sql.functions.theta_sketch_agg` + :meth:`pyspark.sql.functions.theta_sketch_estimate` + + Examples + -------- + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([(1,4),(2,5),(2,5),(3,6)], "struct<v1:int,v2:int>") + >>> df = df.agg( + ... sf.theta_sketch_agg("v1").alias("sketch1"), + ... sf.theta_sketch_agg("v2").alias("sketch2") + ... ) + >>> df.select(sf.theta_sketch_estimate(sf.theta_union(df.sketch1, "sketch2"))).show() + +--------------------------------------------------------+ + |theta_sketch_estimate(theta_union(sketch1, sketch2, 12))| + +--------------------------------------------------------+ + | 6| + +--------------------------------------------------------+ + """ + from pyspark.sql.classic.column import _to_java_column + + fn = "theta_union" + if lgNomEntries is not None: + return _invoke_function( + fn, + _to_java_column(col1), + _to_java_column(col2), + _enum_to_value(lgNomEntries), Review Comment: Why do we use `_enum_to_value` here instead of `lit`? Same question for other functions below. ########## sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/thetasketchesAggregates.scala: ########## @@ -0,0 +1,662 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.datasketches.memory.Memory +import org.apache.datasketches.theta.{CompactSketch, Intersection, SetOperation, Sketch, Union, UpdateSketch, UpdateSketchBuilder} + +import org.apache.spark.SparkUnsupportedOperationException +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription, Literal} +import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate +import org.apache.spark.sql.catalyst.trees.BinaryLike +import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, ThetaSketchUtils} +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.internal.types.StringTypeWithCollation +import org.apache.spark.sql.types.{AbstractDataType, ArrayType, BinaryType, DataType, DoubleType, FloatType, IntegerType, LongType, StringType, TypeCollection} +import org.apache.spark.unsafe.types.UTF8String + +sealed trait ThetaSketchState { + def serialize(): Array[Byte] + def eval(): Array[Byte] +} +case class UpdatableSketchBuffer(sketch: UpdateSketch) extends ThetaSketchState { + override def serialize(): Array[Byte] = sketch.rebuild.compact.toByteArrayCompressed + override def eval(): Array[Byte] = sketch.rebuild.compact.toByteArrayCompressed +} +case class UnionAggregationBuffer(union: Union) extends ThetaSketchState { + override def serialize(): Array[Byte] = union.getResult.toByteArrayCompressed + override def eval(): Array[Byte] = union.getResult.toByteArrayCompressed +} +case class IntersectionAggregationBuffer(intersection: Intersection) extends ThetaSketchState { + override def serialize(): Array[Byte] = intersection.getResult.toByteArrayCompressed + override def eval(): Array[Byte] = intersection.getResult.toByteArrayCompressed +} +case class FinalizedSketch(sketch: CompactSketch) extends ThetaSketchState { + override def serialize(): Array[Byte] = sketch.toByteArrayCompressed + override def eval(): Array[Byte] = sketch.toByteArrayCompressed +} + +/** + * The ThetaSketchAgg function utilizes a Datasketches ThetaSketch instance to count a + * probabilistic approximation of the number of unique values in a given column, and outputs the + * binary representation of the ThetaSketch. + * + * See [[https://datasketches.apache.org/docs/Theta/ThetaSketches.html]] for more information. + * + * @param left + * child expression against which unique counting will occur + * @param right + * the log-base-2 of nomEntries decides the number of buckets for the sketch + * @param mutableAggBufferOffset + * offset for mutable aggregation buffer + * @param inputAggBufferOffset + * offset for input aggregation buffer + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(expr, lgNomEntries) - Returns the ThetaSketch compact binary representation. + `lgNomEntries` (optional) is the log-base-2 of nominal entries, with nominal entries deciding + the number buckets or slots for the ThetaSketch. 
""", + examples = """ + Examples: + > SELECT theta_sketch_estimate(_FUNC_(col, 12)) FROM VALUES (1), (1), (2), (2), (3) tab(col); + 3 + """, + group = "agg_funcs", + since = "4.1.0") +// scalastyle:on line.size.limit +case class ThetaSketchAgg( + left: Expression, + right: Expression, + override val mutableAggBufferOffset: Int, + override val inputAggBufferOffset: Int) + extends TypedImperativeAggregate[ThetaSketchState] + with BinaryLike[Expression] + with ExpectsInputTypes { + + // ThetaSketch config - mark as lazy so that they're not evaluated during tree transformation. + + lazy val lgNomEntries: Int = { + val lgNomEntriesInput = right.eval().asInstanceOf[Int] + ThetaSketchUtils.checkLgNomLongs(lgNomEntriesInput, prettyName) + lgNomEntriesInput + } + + // Constructors + + def this(child: Expression) = { + this(child, Literal(ThetaSketchUtils.DEFAULT_LG_NOM_LONGS), 0, 0) + } + + def this(child: Expression, lgNomEntries: Expression) = { + this(child, lgNomEntries, 0, 0) + } + + def this(child: Expression, lgNomEntries: Int) = { + this(child, Literal(lgNomEntries), 0, 0) + } + + // Copy constructors required by ImperativeAggregate + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ThetaSketchAgg = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ThetaSketchAgg = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override protected def withNewChildrenInternal( + newLeft: Expression, + newRight: Expression): ThetaSketchAgg = + copy(left = newLeft, right = newRight) + + // Overrides for TypedImperativeAggregate + + override def prettyName: String = "theta_sketch_agg" + + override def inputTypes: Seq[AbstractDataType] = + Seq( + TypeCollection( + ArrayType(IntegerType), + ArrayType(LongType), + BinaryType, + DoubleType, + FloatType, + IntegerType, + LongType, + StringTypeWithCollation(supportsTrimCollation = true)), + IntegerType) + + override def dataType: DataType = BinaryType + + override def nullable: Boolean = false + + /** + * Instantiate an UpdateSketch instance using the lgNomEntries param. + * + * @return + * an UpdateSketch instance wrapped with UpdatableSketchBuffer + */ + override def createAggregationBuffer(): ThetaSketchState = { + val builder = new UpdateSketchBuilder + builder.setLogNominalEntries(lgNomEntries) + UpdatableSketchBuffer(builder.build) + } + + /** + * Evaluate the input row and update the UpdateSketch instance with the row's value. The update + * function only supports a subset of Spark SQL types, and an exception will be thrown for + * unsupported types. + * + * @param updateBuffer + * A previously initialized UpdateSketch instance + * @param input + * An input row + */ + override def update(updateBuffer: ThetaSketchState, input: InternalRow): ThetaSketchState = { + // Return early for null values. + val v = left.eval(input) + if (v == null) return updateBuffer + + // Initialized buffer should be UpdatableSketchBuffer, else error out. + val sketch = updateBuffer match { + case UpdatableSketchBuffer(s) => s + case _ => throw QueryExecutionErrors.thetaInvalidInputSketchBuffer(prettyName) + } + + // Handle the different data types for sketch updates. + left.dataType match { + case ArrayType(IntegerType, _) => + val arr = v.asInstanceOf[ArrayData].toIntArray() + if (arr.nonEmpty) sketch.update(arr) Review Comment: Here and below (array of longs and byte arrays): Why do we not update the sketch when the input array is empty? 
Specifically, for an empty byte array (which is a valid value of type BINARY) doesn't the current logic completely disregard this valid value as if it never existed in the column? ########## python/pyspark/sql/functions/builtin.py: ########## @@ -25704,6 +25705,385 @@ def hll_union( return _invoke_function("hll_union", _to_java_column(col1), _to_java_column(col2)) +@_try_remote_functions +def theta_sketch_agg( + col: "ColumnOrName", + lgNomEntries: Optional[Union[int, Column]] = None, +) -> Column: + """ + Aggregate function: returns the compact binary representation of the Datasketches + ThetaSketch configured with lgNomEntries arg. + + .. versionadded:: 4.1.0 + + Parameters + ---------- + col : :class:`~pyspark.sql.Column` or column name + lgNomEntries : :class:`~pyspark.sql.Column` or int, optional + The log-base-2 of nominal entries, where nominal entries is the size of the sketch + (must be between 4 and 26, defaults to 12) + + Returns + ------- + :class:`~pyspark.sql.Column` + The binary representation of the ThetaSketch. + + See Also + -------- + :meth:`pyspark.sql.functions.theta_union` + :meth:`pyspark.sql.functions.theta_intersection` + :meth:`pyspark.sql.functions.theta_difference` + :meth:`pyspark.sql.functions.theta_union_agg` + :meth:`pyspark.sql.functions.theta_intersection_agg` + :meth:`pyspark.sql.functions.theta_sketch_estimate` + + Examples + -------- + >>> from pyspark.sql import functions as sf + >>> df = spark.createDataFrame([1,2,2,3], "INT") + >>> df.agg(sf.theta_sketch_estimate(sf.theta_sketch_agg("value"))).show() + +--------------------------------------------------+ + |theta_sketch_estimate(theta_sketch_agg(value, 12))| + +--------------------------------------------------+ + | 3| + +--------------------------------------------------+ + + >>> df.agg(sf.theta_sketch_estimate(sf.theta_sketch_agg("value", 15))).show() + +--------------------------------------------------+ + |theta_sketch_estimate(theta_sketch_agg(value, 15))| + +--------------------------------------------------+ + | 3| + +--------------------------------------------------+ + """ + fn = "theta_sketch_agg" + if lgNomEntries is None: + return _invoke_function_over_columns(fn, col) + else: + return _invoke_function_over_columns(fn, col, lit(lgNomEntries)) + + +@_try_remote_functions +def theta_union_agg( + col: "ColumnOrName", + lgNomEntries: Optional[Union[int, Column]] = None, +) -> Column: + """ + Aggregate function: returns the compact binary representation of the Datasketches + ThetaSketch, generated by merging previously created Datasketches ThetaSketch instances + via a Datasketches Union instance. Review Comment: ```suggestion Aggregate function: Returns the compact binary representation of the DataSketches' ThetaSketch that is the union of the Theta sketches in the input column. ``` ########## python/pyspark/sql/functions/builtin.py: ########## @@ -25704,6 +25705,385 @@ def hll_union( return _invoke_function("hll_union", _to_java_column(col1), _to_java_column(col2)) +@_try_remote_functions +def theta_sketch_agg( + col: "ColumnOrName", + lgNomEntries: Optional[Union[int, Column]] = None, +) -> Column: + """ + Aggregate function: returns the compact binary representation of the Datasketches + ThetaSketch configured with lgNomEntries arg. Review Comment: What I am missing here is what the Theta sketch refers to.
I would have expected something like: ``` Aggregate function: Returns the compact binary representation of the DataSketches' ThetaSketch of the values in the input column, configured with `lgNomEntries` nominal entries. ``` ########## sql/api/src/main/scala/org/apache/spark/sql/functions.scala: ########## @@ -1165,6 +1165,169 @@ object functions { */ def sum_distinct(e: Column): Column = Column.fn("sum", isDistinct = true, e) + /** + * Aggregate function: returns the compact binary representation of the Datasketches + * ThetaSketch, generated by intersecting previously created Datasketches ThetaSketch instances + * via a Datasketches Intersection instance. Allows setting of log nominal entries for the + * intersection buffer. Review Comment: I have seen this pattern in many similar comments: instead of explaining what the function does, the explanation is more in the direction of _how_ the function operates. I understand that explaining how the function operates is a correct way of explaining what it does, but is this really useful to the reader? I guess my question is more relevant when it comes to documenting the functions (see my comments in the PySpark APIs). ########## sql/api/src/main/scala/org/apache/spark/sql/functions.scala: ########## Review Comment: I am a bit confused about why the functions in this file have been placed where they are. Why are they not grouped together? Are we trying to keep some kind of alphabetical ordering (based on the function name)? ########## sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ThetasketchesAggSuite.scala: ########## @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import scala.collection.immutable.NumericRange +import scala.util.Random + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{BoundReference, ThetaSketchEstimate} +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.types.{ArrayType, BinaryType, DataType, DoubleType, FloatType, IntegerType, LongType, StringType} +import org.apache.spark.unsafe.types.UTF8String + +class ThetasketchesAggSuite extends SparkFunSuite { + + def simulateUpdateMerge( + dataType: DataType, + input: Seq[Any], + numSketches: Integer = 5): (Long, NumericRange[Long]) = { + + // Create a map of the agg function instances. + val aggFunctionMap = Seq + .tabulate(numSketches)(index => { + val sketch = new ThetaSketchAgg(BoundReference(0, dataType, nullable = true)) + index -> (sketch, sketch.createAggregationBuffer()) + }) + .toMap + + // Randomly update agg function instances.
+ input.map(value => { + val (aggFunction, aggBuffer) = aggFunctionMap(Random.nextInt(numSketches)) + aggFunction.update(aggBuffer, InternalRow(value)) + }) + + def serializeDeserialize( + tuple: (ThetaSketchAgg, ThetaSketchState)): (ThetaSketchAgg, ThetaSketchState) = { + val (agg, buf) = tuple + val serialized = agg.serialize(buf) + (agg, agg.deserialize(serialized)) + } + + // Simulate serialization -> deserialization -> merge. + val mapValues = aggFunctionMap.values + val (mergedAgg, UnionAggregationBuffer(mergedBuf)) = + mapValues.tail.foldLeft(mapValues.head)((prev, cur) => { + val (prevAgg, prevBuf) = serializeDeserialize(prev) + val (_, curBuf) = serializeDeserialize(cur) + + (prevAgg, prevAgg.merge(prevBuf, curBuf)) + }) + + val estimator = ThetaSketchEstimate(BoundReference(0, BinaryType, nullable = true)) + val estimate = + estimator.eval(InternalRow(mergedBuf.getResult.toByteArrayCompressed)).asInstanceOf[Long] + ( + estimate, + mergedBuf.getResult.getLowerBound(3).toLong to mergedBuf.getResult.getUpperBound(3).toLong) + } + + test("SPARK-52407: Test min/max values of supported datatypes") { + val intRange = Integer.MIN_VALUE to Integer.MAX_VALUE by 10000000 + val (intEstimate, intEstimateRange) = simulateUpdateMerge(IntegerType, intRange) + assert(intEstimate == intRange.size || intEstimateRange.contains(intRange.size.toLong)) + + val longRange = Long.MinValue to Long.MaxValue by 1000000000000000L + val (longEstimate, longEstimateRange) = simulateUpdateMerge(LongType, longRange) + assert(longEstimate == longRange.size || longEstimateRange.contains(longRange.size.toLong)) + + val stringRange = Seq.tabulate(1000)(i => UTF8String.fromString(Random.nextString(i + 1))) + val (stringEstimate, stringEstimateRange) = simulateUpdateMerge(StringType, stringRange) + assert( + stringEstimate == stringRange.size || + stringEstimateRange.contains(stringRange.size.toLong)) + + val binaryRange = + Seq.tabulate(1000)(i => UTF8String.fromString(Random.nextString(i + 1)).getBytes) + val (binaryEstimate, binaryEstimateRange) = simulateUpdateMerge(BinaryType, binaryRange) + assert( + binaryEstimate == binaryRange.size || + binaryEstimateRange.contains(binaryRange.size.toLong)) + + val floatRange = (1 to 1000).map(_.toFloat) + val (floatEstimate, floatRangeEst) = simulateUpdateMerge(FloatType, floatRange) + assert(floatEstimate == floatRange.size || floatRangeEst.contains(floatRange.size.toLong)) + + val doubleRange = (1 to 1000).map(_.toDouble) + val (doubleEstimate, doubleRangeEst) = simulateUpdateMerge(DoubleType, doubleRange) + assert(doubleEstimate == doubleRange.size || doubleRangeEst.contains(doubleRange.size.toLong)) + + val arrayIntRange = (1 to 500).map(i => ArrayData.toArrayData(Array(i, i + 1))) + val (arrayIntEstimate, arrayIntRangeEst) = + simulateUpdateMerge(ArrayType(IntegerType), arrayIntRange) + assert( + arrayIntEstimate == arrayIntRange.size || + arrayIntRangeEst.contains(arrayIntRange.size.toLong)) + + val arrayLongRange = + (1 to 500).map(i => ArrayData.toArrayData(Array(i.toLong, (i + 1).toLong))) + val (arrayLongEstimate, arrayLongRangeEst) = + simulateUpdateMerge(ArrayType(LongType), arrayLongRange) + assert( + arrayLongEstimate == arrayLongRange.size || + arrayLongRangeEst.contains(arrayLongRange.size.toLong)) + } + + test("SPARK-52407: Test lgNomEntries results in downsampling sketches during Union") { + // Create a sketch with larger configuration (more precise). 
+ val aggFunc1 = new ThetaSketchAgg(BoundReference(0, IntegerType, nullable = true), 12) + val sketch1 = aggFunc1.createAggregationBuffer() + (0 to 100).map(i => aggFunc1.update(sketch1, InternalRow(i))) + val binary1 = aggFunc1.eval(sketch1) + + // Create a sketch with smaller configuration (less precise). + val aggFunc2 = new ThetaSketchAgg(BoundReference(0, IntegerType, nullable = true), 10) + val sketch2 = aggFunc2.createAggregationBuffer() + (0 to 100).map(i => aggFunc2.update(sketch2, InternalRow(i))) + val binary2 = aggFunc2.eval(sketch2) + + // Union the sketches. + val unionAgg = new ThetaUnionAgg(BoundReference(0, BinaryType, nullable = true), 12) Review Comment: Should we allow this to happen so silently? Users can very easily union sketches and ask for a higher nominal value than the sketches (one or both) support. In this case we would break any accuracy guarantees the sketches individually might have, again all very silently. To be honest I am in favor of one of two approaches: * Do not allow specifying a nominal value when doing boolean set operations. In this case the behavior is that we choose the less precise nominal value for the result. * We provide a function to change the nominal value of a sketch (in either direction, lower or higher). This way we force users to be aware of what they are doing, as opposed to silently pretending that we were able to create a sketch with the specified precision. Having said the above, we need to consider if we need one more scalar function that returns the nominal value of a sketch. ########## sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/thetasketchesAggregates.scala: ########## @@ -0,0 +1,662 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.datasketches.memory.Memory +import org.apache.datasketches.theta.{CompactSketch, Intersection, SetOperation, Sketch, Union, UpdateSketch, UpdateSketchBuilder} + +import org.apache.spark.SparkUnsupportedOperationException +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription, Literal} +import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate +import org.apache.spark.sql.catalyst.trees.BinaryLike +import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, ThetaSketchUtils} +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.internal.types.StringTypeWithCollation +import org.apache.spark.sql.types.{AbstractDataType, ArrayType, BinaryType, DataType, DoubleType, FloatType, IntegerType, LongType, StringType, TypeCollection} +import org.apache.spark.unsafe.types.UTF8String + +sealed trait ThetaSketchState { + def serialize(): Array[Byte] + def eval(): Array[Byte] +} +case class UpdatableSketchBuffer(sketch: UpdateSketch) extends ThetaSketchState { + override def serialize(): Array[Byte] = sketch.rebuild.compact.toByteArrayCompressed + override def eval(): Array[Byte] = sketch.rebuild.compact.toByteArrayCompressed +} +case class UnionAggregationBuffer(union: Union) extends ThetaSketchState { + override def serialize(): Array[Byte] = union.getResult.toByteArrayCompressed + override def eval(): Array[Byte] = union.getResult.toByteArrayCompressed +} +case class IntersectionAggregationBuffer(intersection: Intersection) extends ThetaSketchState { + override def serialize(): Array[Byte] = intersection.getResult.toByteArrayCompressed + override def eval(): Array[Byte] = intersection.getResult.toByteArrayCompressed +} +case class FinalizedSketch(sketch: CompactSketch) extends ThetaSketchState { + override def serialize(): Array[Byte] = sketch.toByteArrayCompressed + override def eval(): Array[Byte] = sketch.toByteArrayCompressed +} + +/** + * The ThetaSketchAgg function utilizes a Datasketches ThetaSketch instance to count a + * probabilistic approximation of the number of unique values in a given column, and outputs the + * binary representation of the ThetaSketch. + * + * See [[https://datasketches.apache.org/docs/Theta/ThetaSketches.html]] for more information. + * + * @param left + * child expression against which unique counting will occur + * @param right + * the log-base-2 of nomEntries decides the number of buckets for the sketch + * @param mutableAggBufferOffset + * offset for mutable aggregation buffer + * @param inputAggBufferOffset + * offset for input aggregation buffer + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(expr, lgNomEntries) - Returns the ThetaSketch compact binary representation. + `lgNomEntries` (optional) is the log-base-2 of nominal entries, with nominal entries deciding + the number buckets or slots for the ThetaSketch. 
""", + examples = """ + Examples: + > SELECT theta_sketch_estimate(_FUNC_(col, 12)) FROM VALUES (1), (1), (2), (2), (3) tab(col); + 3 + """, + group = "agg_funcs", + since = "4.1.0") +// scalastyle:on line.size.limit +case class ThetaSketchAgg( + left: Expression, + right: Expression, + override val mutableAggBufferOffset: Int, + override val inputAggBufferOffset: Int) + extends TypedImperativeAggregate[ThetaSketchState] + with BinaryLike[Expression] + with ExpectsInputTypes { + + // ThetaSketch config - mark as lazy so that they're not evaluated during tree transformation. + + lazy val lgNomEntries: Int = { + val lgNomEntriesInput = right.eval().asInstanceOf[Int] + ThetaSketchUtils.checkLgNomLongs(lgNomEntriesInput, prettyName) + lgNomEntriesInput + } + + // Constructors + + def this(child: Expression) = { + this(child, Literal(ThetaSketchUtils.DEFAULT_LG_NOM_LONGS), 0, 0) + } + + def this(child: Expression, lgNomEntries: Expression) = { + this(child, lgNomEntries, 0, 0) + } + + def this(child: Expression, lgNomEntries: Int) = { + this(child, Literal(lgNomEntries), 0, 0) + } + + // Copy constructors required by ImperativeAggregate + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ThetaSketchAgg = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ThetaSketchAgg = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + override protected def withNewChildrenInternal( + newLeft: Expression, + newRight: Expression): ThetaSketchAgg = + copy(left = newLeft, right = newRight) + + // Overrides for TypedImperativeAggregate + + override def prettyName: String = "theta_sketch_agg" + + override def inputTypes: Seq[AbstractDataType] = + Seq( + TypeCollection( + ArrayType(IntegerType), + ArrayType(LongType), + BinaryType, + DoubleType, + FloatType, + IntegerType, + LongType, + StringTypeWithCollation(supportsTrimCollation = true)), + IntegerType) + + override def dataType: DataType = BinaryType + + override def nullable: Boolean = false + + /** + * Instantiate an UpdateSketch instance using the lgNomEntries param. + * + * @return + * an UpdateSketch instance wrapped with UpdatableSketchBuffer + */ + override def createAggregationBuffer(): ThetaSketchState = { + val builder = new UpdateSketchBuilder + builder.setLogNominalEntries(lgNomEntries) + UpdatableSketchBuffer(builder.build) + } + + /** + * Evaluate the input row and update the UpdateSketch instance with the row's value. The update + * function only supports a subset of Spark SQL types, and an exception will be thrown for + * unsupported types. + * + * @param updateBuffer + * A previously initialized UpdateSketch instance + * @param input + * An input row + */ + override def update(updateBuffer: ThetaSketchState, input: InternalRow): ThetaSketchState = { + // Return early for null values. + val v = left.eval(input) + if (v == null) return updateBuffer + + // Initialized buffer should be UpdatableSketchBuffer, else error out. + val sketch = updateBuffer match { + case UpdatableSketchBuffer(s) => s + case _ => throw QueryExecutionErrors.thetaInvalidInputSketchBuffer(prettyName) + } + + // Handle the different data types for sketch updates. 
+ left.dataType match { + case ArrayType(IntegerType, _) => + val arr = v.asInstanceOf[ArrayData].toIntArray() + if (arr.nonEmpty) sketch.update(arr) + case ArrayType(LongType, _) => + val arr = v.asInstanceOf[ArrayData].toLongArray() + if (arr.nonEmpty) sketch.update(arr) + case BinaryType => + val bytes = v.asInstanceOf[Array[Byte]] + if (bytes.nonEmpty) sketch.update(bytes) + case DoubleType => + sketch.update(v.asInstanceOf[Double]) + case FloatType => + sketch.update(v.asInstanceOf[Float].toDouble) // Float is promoted to double. + case IntegerType => + sketch.update(v.asInstanceOf[Int].toLong) // Int is promoted to Long. + case LongType => + sketch.update(v.asInstanceOf[Long]) + case st: StringType => + val cKey = + CollationFactory.getCollationKey(v.asInstanceOf[UTF8String], st.collationId) + sketch.update(cKey.toString) Review Comment: I see a problem with this implementation. Technically speaking, the collation key of a string is a byte array. Here we use the `CollationFactory` API that produces the collation key as a `UTF8String` (basically this populates a `UTF8String` object with the bytes of the collation key), but then we convert `cKey` to a Java string. I believe the latter conversion is extremely problematic because the Java constructor used in the `toString()` implementation will convert invalid UTF-8 sequences to U+FFFD, and this is very likely if the input is a byte array like the one we get as the collation key of the string. Basically, what I am saying is that here we may end up treating strings that are different as the same, just because of the way we construct the string we pass to the sketch. I think a safer way is the following: ```scala val cKey = CollationFactory.getCollationKeyBytes(v.asInstanceOf[UTF8String], st.collationId) sketch.update(cKey) ``` What is not clear to me in my suggestion above is what happens with empty strings, because they might produce an empty byte array as the collation key. Let's try to figure this out.
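For illustration, here is a minimal, self-contained Scala sketch of the decoding loss described above. It uses plain JDK decoding; that `UTF8String.toString` follows the same replacement behavior is an assumption to verify against its implementation, and `CollationKeyLossDemo` is a hypothetical standalone object, not part of the PR.

```scala
import java.nio.charset.StandardCharsets

// Hypothetical demo: two distinct collation-key byte arrays can decode to
// the same Java string, collapsing distinct values before the sketch update.
object CollationKeyLossDemo {
  def main(args: Array[String]): Unit = {
    val key1 = Array(0xC3.toByte) // truncated two-byte UTF-8 sequence
    val key2 = Array(0xFF.toByte) // 0xFF never occurs in valid UTF-8

    // Decoding replaces each malformed sequence with U+FFFD.
    val s1 = new String(key1, StandardCharsets.UTF_8)
    val s2 = new String(key2, StandardCharsets.UTF_8)

    println(key1.sameElements(key2)) // false: the raw keys differ
    println(s1 == s2)                // true: both decode to "\uFFFD"
    // A sketch updated with s1 and s2 would count one distinct value,
    // while updating with the raw key bytes would count two.
  }
}
```

If that assumption holds, any two keys whose malformed regions happen to align become indistinguishable after the conversion, which is exactly the concern raised above.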
########## sql/core/src/test/resources/sql-tests/results/thetasketch.sql.out: ########## @@ -0,0 +1,1208 @@ +-- Automatically generated by SQLQueryTestSuite +-- !query +DROP TABLE IF EXISTS t_int_1_5_through_7_11 +-- !query schema +struct<> +-- !query output + + + +-- !query +CREATE TABLE t_int_1_5_through_7_11 AS +VALUES + (1, 5), (2, 6), (3, 7), (4, 8), (5, 9), (6, 10), (7, 11) AS tab(col1, col2) +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP TABLE IF EXISTS t_long_1_5_through_7_11 +-- !query schema +struct<> +-- !query output + + + +-- !query +CREATE TABLE t_long_1_5_through_7_11 AS +VALUES + (1L, 5L), (2L, 6L), (3L, 7L), (4L, 8L), (5L, 9L), (6L, 10L), (7L, 11L) AS tab(col1, col2) +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP TABLE IF EXISTS t_double_1_1_1_4_through_1_5_1_8 +-- !query schema +struct<> +-- !query output + + + +-- !query +CREATE TABLE t_double_1_1_1_4_through_1_5_1_8 AS +SELECT CAST(col1 AS DOUBLE) AS col1, CAST(col2 AS DOUBLE) AS col2 +FROM VALUES + (1.1, 1.4), (1.2, 1.5), (1.3, 1.6), (1.4, 1.7), (1.5, 1.8) AS tab(col1, col2) +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP TABLE IF EXISTS t_float_1_1_1_4_through_1_5_1_8 +-- !query schema +struct<> +-- !query output + + + +-- !query +CREATE TABLE t_float_1_1_1_4_through_1_5_1_8 AS +SELECT CAST(col1 AS FLOAT) col1, CAST(col2 AS FLOAT) col2 +FROM VALUES + (1.1, 1.4), (1.2, 1.5), (1.3, 1.6), (1.4, 1.7), (1.5, 1.8) AS tab(col1, col2) +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP TABLE IF EXISTS t_string_a_d_through_e_h +-- !query schema +struct<> +-- !query output + + + +-- !query +CREATE TABLE t_string_a_d_through_e_h AS +VALUES + ('a', 'd'), ('b', 'e'), ('c', 'f'), ('d', 'g'), ('e', 'h') AS tab(col1, col2) +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP TABLE IF EXISTS t_binary_a_b_through_e_f +-- !query schema +struct<> +-- !query output + + + +-- !query +CREATE TABLE t_binary_a_b_through_e_f AS +VALUES + (X'A', X'B'), (X'B', X'C'), (X'C', X'D'), (X'D', X'E'), (X'E', X'F') AS tab(col1, col2) +-- !query schema +struct<> +-- !query output + + + +-- !query +DROP TABLE IF EXISTS t_array_int_1_3_through_4_6 +-- !query schema +struct<> +-- !query output + + + +-- !query +CREATE TABLE t_array_int_1_3_through_4_6 AS +VALUES + (ARRAY(1), ARRAY(3)), + (ARRAY(2), ARRAY(4)), + (ARRAY(3), ARRAY(5)), + (ARRAY(4), ARRAY(6)) AS tab(col1, col2) +-- !query schema +struct<> +-- !query output Review Comment: Could we add an example where we have arrays with more than one element? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org