cboumalh commented on code in PR #52800: URL: https://github.com/apache/spark/pull/52800#discussion_r2612518662
########## sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/kllExpressions.scala: ########## @@ -0,0 +1,671 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.datasketches.kll.{KllDoublesSketch, KllFloatsSketch, KllLongsSketch} +import org.apache.datasketches.memory.Memory + +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.types.{AbstractDataType, ArrayType, BinaryType, DataType, DoubleType, FloatType, LongType, StringType, TypeCollection} +import org.apache.spark.unsafe.types.UTF8String + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(expr) - Returns human readable summary information about this sketch. 
+ """, + examples = """ + Examples: + > SELECT LENGTH(_FUNC_(kll_sketch_agg_bigint(col))) > 0 FROM VALUES (1), (2), (3), (4), (5) tab(col); + true + """, + group = "misc_funcs", + since = "4.1.0") +case class KllSketchToStringBigint(child: Expression) extends KllSketchToStringBase { + override protected def withNewChildInternal(newChild: Expression): KllSketchToStringBigint = + copy(child = newChild) + override def prettyName: String = "kll_sketch_to_string_bigint" + override def nullSafeEval(input: Any): Any = { + try { + val buffer = input.asInstanceOf[Array[Byte]] + val sketch = KllLongsSketch.heapify(Memory.wrap(buffer)) + UTF8String.fromString(sketch.toString()) + } catch { + case e: Exception => + throw QueryExecutionErrors.kllSketchInvalidInputError(prettyName, e.getMessage) + } + } +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(expr) - Returns human readable summary information about this sketch. + """, + examples = """ + Examples: + > SELECT LENGTH(_FUNC_(kll_sketch_agg_float(col))) > 0 FROM VALUES (CAST(1.0 AS FLOAT)), (CAST(2.0 AS FLOAT)), (CAST(3.0 AS FLOAT)), (CAST(4.0 AS FLOAT)), (CAST(5.0 AS FLOAT)) tab(col); + true + """, + group = "misc_funcs", + since = "4.1.0") +case class KllSketchToStringFloat(child: Expression) extends KllSketchToStringBase { + override protected def withNewChildInternal(newChild: Expression): KllSketchToStringFloat = + copy(child = newChild) + override def prettyName: String = "kll_sketch_to_string_float" + override def nullSafeEval(input: Any): Any = { + try { + val buffer = input.asInstanceOf[Array[Byte]] + val sketch = KllFloatsSketch.heapify(Memory.wrap(buffer)) + UTF8String.fromString(sketch.toString()) + } catch { + case e: Exception => + throw QueryExecutionErrors.kllSketchInvalidInputError(prettyName, e.getMessage) + } + } +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(expr) - Returns human readable summary information about this sketch. 
+ """, + examples = """ + Examples: + > SELECT LENGTH(_FUNC_(kll_sketch_agg_double(col))) > 0 FROM VALUES (CAST(1.0 AS DOUBLE)), (CAST(2.0 AS DOUBLE)), (CAST(3.0 AS DOUBLE)), (CAST(4.0 AS DOUBLE)), (CAST(5.0 AS DOUBLE)) tab(col); + true + """, + group = "misc_funcs", + since = "4.1.0") +case class KllSketchToStringDouble(child: Expression) extends KllSketchToStringBase { + override protected def withNewChildInternal(newChild: Expression): KllSketchToStringDouble = + copy(child = newChild) + override def prettyName: String = "kll_sketch_to_string_double" + override def nullSafeEval(input: Any): Any = { + try { + val buffer = input.asInstanceOf[Array[Byte]] + val sketch = KllDoublesSketch.heapify(Memory.wrap(buffer)) + UTF8String.fromString(sketch.toString()) + } catch { + case e: Exception => + throw QueryExecutionErrors.kllSketchInvalidInputError(prettyName, e.getMessage) + } + } +} + +/** This is a base class for the above expressions to reduce boilerplate. */ +abstract class KllSketchToStringBase + extends UnaryExpression + with CodegenFallback + with ImplicitCastInputTypes { + override def dataType: DataType = StringType + override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType) + override def nullIntolerant: Boolean = true +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(expr) - Returns the number of items collected in the sketch. 
+ """, + examples = """ + Examples: + > SELECT _FUNC_(kll_sketch_agg_bigint(col)) FROM VALUES (1), (2), (3), (4), (5) tab(col); + 5 + """, + group = "misc_funcs", + since = "4.1.0") +case class KllSketchGetNBigint(child: Expression) extends KllSketchGetNBase { + override protected def withNewChildInternal(newChild: Expression): KllSketchGetNBigint = + copy(child = newChild) + override def prettyName: String = "kll_sketch_get_n_bigint" + override def nullSafeEval(input: Any): Any = { + try { + val buffer = input.asInstanceOf[Array[Byte]] + val sketch = KllLongsSketch.heapify(Memory.wrap(buffer)) + sketch.getN() + } catch { + case e: Exception => + throw QueryExecutionErrors.kllSketchInvalidInputError(prettyName, e.getMessage) + } + } +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(expr) - Returns the number of items collected in the sketch. + """, + examples = """ + Examples: + > SELECT _FUNC_(kll_sketch_agg_float(col)) FROM VALUES (CAST(1.0 AS FLOAT)), (CAST(2.0 AS FLOAT)), (CAST(3.0 AS FLOAT)), (CAST(4.0 AS FLOAT)), (CAST(5.0 AS FLOAT)) tab(col); + 5 + """, + group = "misc_funcs", + since = "4.1.0") +case class KllSketchGetNFloat(child: Expression) extends KllSketchGetNBase { + override protected def withNewChildInternal(newChild: Expression): KllSketchGetNFloat = + copy(child = newChild) + override def prettyName: String = "kll_sketch_get_n_float" + override def nullSafeEval(input: Any): Any = { + try { + val buffer = input.asInstanceOf[Array[Byte]] + val sketch = KllFloatsSketch.heapify(Memory.wrap(buffer)) + sketch.getN() + } catch { + case e: Exception => + throw QueryExecutionErrors.kllSketchInvalidInputError(prettyName, e.getMessage) + } + } +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(expr) - Returns the number of items collected in the sketch. 
+ """, + examples = """ + Examples: + > SELECT _FUNC_(kll_sketch_agg_double(col)) FROM VALUES (CAST(1.0 AS DOUBLE)), (CAST(2.0 AS DOUBLE)), (CAST(3.0 AS DOUBLE)), (CAST(4.0 AS DOUBLE)), (CAST(5.0 AS DOUBLE)) tab(col); + 5 + """, + group = "misc_funcs", + since = "4.1.0") +case class KllSketchGetNDouble(child: Expression) extends KllSketchGetNBase { + override protected def withNewChildInternal(newChild: Expression): KllSketchGetNDouble = + copy(child = newChild) + override def prettyName: String = "kll_sketch_get_n_double" + override def nullSafeEval(input: Any): Any = { + try { + val buffer = input.asInstanceOf[Array[Byte]] + val sketch = KllDoublesSketch.heapify(Memory.wrap(buffer)) + sketch.getN() + } catch { + case e: Exception => + throw QueryExecutionErrors.kllSketchInvalidInputError(prettyName, e.getMessage) + } + } +} + +/** This is a base class for the above expressions to reduce boilerplate. */ +abstract class KllSketchGetNBase + extends UnaryExpression + with CodegenFallback + with ImplicitCastInputTypes { + override def dataType: DataType = LongType + override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType) + override def nullIntolerant: Boolean = true +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(left, right) - Merges two sketch buffers together into one. 
+ """, + examples = """ + Examples: + > SELECT LENGTH(kll_sketch_to_string_bigint(_FUNC_(kll_sketch_agg_bigint(col), kll_sketch_agg_bigint(col)))) > 0 FROM VALUES (1), (2), (3), (4), (5) tab(col); + true + """, + group = "misc_funcs", + since = "4.1.0") +case class KllSketchMergeBigint(left: Expression, right: Expression) extends KllSketchMergeBase { + override def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = + copy(left = newLeft, right = newRight) + override def prettyName: String = "kll_sketch_merge_bigint" + override def nullSafeEval(left: Any, right: Any): Any = { + try { + val leftBuffer = left.asInstanceOf[Array[Byte]] + val rightBuffer = right.asInstanceOf[Array[Byte]] + val leftSketch = KllLongsSketch.heapify(Memory.wrap(leftBuffer)) + val rightSketch = KllLongsSketch.wrap(Memory.wrap(rightBuffer)) + leftSketch.merge(rightSketch) + leftSketch.toByteArray + } catch { + case e: Exception => + throw QueryExecutionErrors.kllSketchIncompatibleMergeError(prettyName, e.getMessage) + } + } +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(left, right) - Merges two sketch buffers together into one. 
+ """, + examples = """ + Examples: + > SELECT LENGTH(kll_sketch_to_string_float(_FUNC_(kll_sketch_agg_float(col), kll_sketch_agg_float(col)))) > 0 FROM VALUES (CAST(1.0 AS FLOAT)), (CAST(2.0 AS FLOAT)), (CAST(3.0 AS FLOAT)), (CAST(4.0 AS FLOAT)), (CAST(5.0 AS FLOAT)) tab(col); + true + """, + group = "misc_funcs", + since = "4.1.0") +case class KllSketchMergeFloat(left: Expression, right: Expression) extends KllSketchMergeBase { + override def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = + copy(left = newLeft, right = newRight) + override def prettyName: String = "kll_sketch_merge_float" + override def nullSafeEval(left: Any, right: Any): Any = { + try { + val leftBuffer = left.asInstanceOf[Array[Byte]] + val rightBuffer = right.asInstanceOf[Array[Byte]] + val leftSketch = KllFloatsSketch.heapify(Memory.wrap(leftBuffer)) + val rightSketch = KllFloatsSketch.wrap(Memory.wrap(rightBuffer)) + leftSketch.merge(rightSketch) + leftSketch.toByteArray + } catch { + case e: Exception => + throw QueryExecutionErrors.kllSketchIncompatibleMergeError(prettyName, e.getMessage) + } + } +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(left, right) - Merges two sketch buffers together into one. + """, + examples = """ + Examples: + > SELECT LENGTH(kll_sketch_to_string_double(_FUNC_(kll_sketch_agg_double(col), kll_sketch_agg_double(col)))) > 0 FROM VALUES (CAST(1.0 AS DOUBLE)), (CAST(2.0 AS DOUBLE)), (CAST(3.0 AS DOUBLE)), (CAST(4.0 AS DOUBLE)), (CAST(5.0 AS DOUBLE)) tab(col); + true + """, + group = "misc_funcs", + since = "4.1.0") +case class KllSketchMergeDouble(left: Expression, right: Expression) extends KllSketchMergeBase { Review Comment: Great point @figure-shao, agreed this would be valuable. 
@dtenedor, if you're planning to extend the KLL functions to support partial aggregation semantics (allowing kll_sketch_merge_* to be used as proper aggregators), I'm happy to help with the implementation or take on the follow-up work if that’s useful. No pressure at all; just let me know what would be most helpful. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
