dtenedor commented on code in PR #52800: URL: https://github.com/apache/spark/pull/52800#discussion_r2670469504
########## sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/kllExpressions.scala: ########## @@ -0,0 +1,671 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.datasketches.kll.{KllDoublesSketch, KllFloatsSketch, KllLongsSketch} +import org.apache.datasketches.memory.Memory + +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} +import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.types.{AbstractDataType, ArrayType, BinaryType, DataType, DoubleType, FloatType, LongType, StringType, TypeCollection} +import org.apache.spark.unsafe.types.UTF8String + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(expr) - Returns human readable summary information about this sketch. + """, + examples = """ + Examples: + > SELECT LENGTH(_FUNC_(kll_sketch_agg_bigint(col))) > 0 FROM VALUES (1), (2), (3), (4), (5) tab(col); + true + """, + group = "misc_funcs", + since = "4.1.0") +case class KllSketchToStringBigint(child: Expression) extends KllSketchToStringBase { + override protected def withNewChildInternal(newChild: Expression): KllSketchToStringBigint = + copy(child = newChild) + override def prettyName: String = "kll_sketch_to_string_bigint" + override def nullSafeEval(input: Any): Any = { + try { + val buffer = input.asInstanceOf[Array[Byte]] + val sketch = KllLongsSketch.heapify(Memory.wrap(buffer)) + UTF8String.fromString(sketch.toString()) + } catch { + case e: Exception => + throw QueryExecutionErrors.kllSketchInvalidInputError(prettyName, e.getMessage) + } + } +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(expr) - Returns human readable summary information about this sketch. + """, + examples = """ + Examples: + > SELECT LENGTH(_FUNC_(kll_sketch_agg_float(col))) > 0 FROM VALUES (CAST(1.0 AS FLOAT)), (CAST(2.0 AS FLOAT)), (CAST(3.0 AS FLOAT)), (CAST(4.0 AS FLOAT)), (CAST(5.0 AS FLOAT)) tab(col); + true + """, + group = "misc_funcs", + since = "4.1.0") +case class KllSketchToStringFloat(child: Expression) extends KllSketchToStringBase { + override protected def withNewChildInternal(newChild: Expression): KllSketchToStringFloat = + copy(child = newChild) + override def prettyName: String = "kll_sketch_to_string_float" + override def nullSafeEval(input: Any): Any = { + try { + val buffer = input.asInstanceOf[Array[Byte]] + val sketch = KllFloatsSketch.heapify(Memory.wrap(buffer)) + UTF8String.fromString(sketch.toString()) + } catch { + case e: Exception => + throw QueryExecutionErrors.kllSketchInvalidInputError(prettyName, e.getMessage) + } + } +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(expr) - Returns human readable summary information about this sketch. + """, + examples = """ + Examples: + > SELECT LENGTH(_FUNC_(kll_sketch_agg_double(col))) > 0 FROM VALUES (CAST(1.0 AS DOUBLE)), (CAST(2.0 AS DOUBLE)), (CAST(3.0 AS DOUBLE)), (CAST(4.0 AS DOUBLE)), (CAST(5.0 AS DOUBLE)) tab(col); + true + """, + group = "misc_funcs", + since = "4.1.0") +case class KllSketchToStringDouble(child: Expression) extends KllSketchToStringBase { + override protected def withNewChildInternal(newChild: Expression): KllSketchToStringDouble = + copy(child = newChild) + override def prettyName: String = "kll_sketch_to_string_double" + override def nullSafeEval(input: Any): Any = { + try { + val buffer = input.asInstanceOf[Array[Byte]] + val sketch = KllDoublesSketch.heapify(Memory.wrap(buffer)) + UTF8String.fromString(sketch.toString()) + } catch { + case e: Exception => + throw QueryExecutionErrors.kllSketchInvalidInputError(prettyName, e.getMessage) + } + } +} + +/** This is a base class for the above expressions to reduce boilerplate. */ +abstract class KllSketchToStringBase + extends UnaryExpression + with CodegenFallback + with ImplicitCastInputTypes { + override def dataType: DataType = StringType + override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType) + override def nullIntolerant: Boolean = true +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(expr) - Returns the number of items collected in the sketch. + """, + examples = """ + Examples: + > SELECT _FUNC_(kll_sketch_agg_bigint(col)) FROM VALUES (1), (2), (3), (4), (5) tab(col); + 5 + """, + group = "misc_funcs", + since = "4.1.0") +case class KllSketchGetNBigint(child: Expression) extends KllSketchGetNBase { + override protected def withNewChildInternal(newChild: Expression): KllSketchGetNBigint = + copy(child = newChild) + override def prettyName: String = "kll_sketch_get_n_bigint" + override def nullSafeEval(input: Any): Any = { + try { + val buffer = input.asInstanceOf[Array[Byte]] + val sketch = KllLongsSketch.heapify(Memory.wrap(buffer)) + sketch.getN() + } catch { + case e: Exception => + throw QueryExecutionErrors.kllSketchInvalidInputError(prettyName, e.getMessage) + } + } +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(expr) - Returns the number of items collected in the sketch. + """, + examples = """ + Examples: + > SELECT _FUNC_(kll_sketch_agg_float(col)) FROM VALUES (CAST(1.0 AS FLOAT)), (CAST(2.0 AS FLOAT)), (CAST(3.0 AS FLOAT)), (CAST(4.0 AS FLOAT)), (CAST(5.0 AS FLOAT)) tab(col); + 5 + """, + group = "misc_funcs", + since = "4.1.0") +case class KllSketchGetNFloat(child: Expression) extends KllSketchGetNBase { + override protected def withNewChildInternal(newChild: Expression): KllSketchGetNFloat = + copy(child = newChild) + override def prettyName: String = "kll_sketch_get_n_float" + override def nullSafeEval(input: Any): Any = { + try { + val buffer = input.asInstanceOf[Array[Byte]] + val sketch = KllFloatsSketch.heapify(Memory.wrap(buffer)) + sketch.getN() + } catch { + case e: Exception => + throw QueryExecutionErrors.kllSketchInvalidInputError(prettyName, e.getMessage) + } + } +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(expr) - Returns the number of items collected in the sketch. + """, + examples = """ + Examples: + > SELECT _FUNC_(kll_sketch_agg_double(col)) FROM VALUES (CAST(1.0 AS DOUBLE)), (CAST(2.0 AS DOUBLE)), (CAST(3.0 AS DOUBLE)), (CAST(4.0 AS DOUBLE)), (CAST(5.0 AS DOUBLE)) tab(col); + 5 + """, + group = "misc_funcs", + since = "4.1.0") +case class KllSketchGetNDouble(child: Expression) extends KllSketchGetNBase { + override protected def withNewChildInternal(newChild: Expression): KllSketchGetNDouble = + copy(child = newChild) + override def prettyName: String = "kll_sketch_get_n_double" + override def nullSafeEval(input: Any): Any = { + try { + val buffer = input.asInstanceOf[Array[Byte]] + val sketch = KllDoublesSketch.heapify(Memory.wrap(buffer)) + sketch.getN() + } catch { + case e: Exception => + throw QueryExecutionErrors.kllSketchInvalidInputError(prettyName, e.getMessage) + } + } +} + +/** This is a base class for the above expressions to reduce boilerplate. */ +abstract class KllSketchGetNBase + extends UnaryExpression + with CodegenFallback + with ImplicitCastInputTypes { + override def dataType: DataType = LongType + override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType) + override def nullIntolerant: Boolean = true +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(left, right) - Merges two sketch buffers together into one. + """, + examples = """ + Examples: + > SELECT LENGTH(kll_sketch_to_string_bigint(_FUNC_(kll_sketch_agg_bigint(col), kll_sketch_agg_bigint(col)))) > 0 FROM VALUES (1), (2), (3), (4), (5) tab(col); + true + """, + group = "misc_funcs", + since = "4.1.0") +case class KllSketchMergeBigint(left: Expression, right: Expression) extends KllSketchMergeBase { + override def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = + copy(left = newLeft, right = newRight) + override def prettyName: String = "kll_sketch_merge_bigint" + override def nullSafeEval(left: Any, right: Any): Any = { + try { + val leftBuffer = left.asInstanceOf[Array[Byte]] + val rightBuffer = right.asInstanceOf[Array[Byte]] + val leftSketch = KllLongsSketch.heapify(Memory.wrap(leftBuffer)) + val rightSketch = KllLongsSketch.wrap(Memory.wrap(rightBuffer)) + leftSketch.merge(rightSketch) + leftSketch.toByteArray + } catch { + case e: Exception => + throw QueryExecutionErrors.kllSketchIncompatibleMergeError(prettyName, e.getMessage) + } + } +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(left, right) - Merges two sketch buffers together into one. + """, + examples = """ + Examples: + > SELECT LENGTH(kll_sketch_to_string_float(_FUNC_(kll_sketch_agg_float(col), kll_sketch_agg_float(col)))) > 0 FROM VALUES (CAST(1.0 AS FLOAT)), (CAST(2.0 AS FLOAT)), (CAST(3.0 AS FLOAT)), (CAST(4.0 AS FLOAT)), (CAST(5.0 AS FLOAT)) tab(col); + true + """, + group = "misc_funcs", + since = "4.1.0") +case class KllSketchMergeFloat(left: Expression, right: Expression) extends KllSketchMergeBase { + override def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = + copy(left = newLeft, right = newRight) + override def prettyName: String = "kll_sketch_merge_float" + override def nullSafeEval(left: Any, right: Any): Any = { + try { + val leftBuffer = left.asInstanceOf[Array[Byte]] + val rightBuffer = right.asInstanceOf[Array[Byte]] + val leftSketch = KllFloatsSketch.heapify(Memory.wrap(leftBuffer)) + val rightSketch = KllFloatsSketch.wrap(Memory.wrap(rightBuffer)) + leftSketch.merge(rightSketch) + leftSketch.toByteArray + } catch { + case e: Exception => + throw QueryExecutionErrors.kllSketchIncompatibleMergeError(prettyName, e.getMessage) + } + } +} + +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """ + _FUNC_(left, right) - Merges two sketch buffers together into one. + """, + examples = """ + Examples: + > SELECT LENGTH(kll_sketch_to_string_double(_FUNC_(kll_sketch_agg_double(col), kll_sketch_agg_double(col)))) > 0 FROM VALUES (CAST(1.0 AS DOUBLE)), (CAST(2.0 AS DOUBLE)), (CAST(3.0 AS DOUBLE)), (CAST(4.0 AS DOUBLE)), (CAST(5.0 AS DOUBLE)) tab(col); + true + """, + group = "misc_funcs", + since = "4.1.0") +case class KllSketchMergeDouble(left: Expression, right: Expression) extends KllSketchMergeBase { Review Comment: @figure-shao that is a good point, we should add this. @cboumalh thanks for adding the implementation! Let's review and merge your PR. I can take a look there next. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
