figure-shao commented on code in PR #52800:
URL: https://github.com/apache/spark/pull/52800#discussion_r2612501151


##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/kllExpressions.scala:
##########
@@ -0,0 +1,671 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.datasketches.kll.{KllDoublesSketch, KllFloatsSketch, 
KllLongsSketch}
+import org.apache.datasketches.memory.Memory
+
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.types.{AbstractDataType, ArrayType, BinaryType, 
DataType, DoubleType, FloatType, LongType, StringType, TypeCollection}
+import org.apache.spark.unsafe.types.UTF8String
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = """
+    _FUNC_(expr) - Returns human readable summary information about this 
sketch.
+  """,
+  examples = """
+    Examples:
+      > SELECT LENGTH(_FUNC_(kll_sketch_agg_bigint(col))) > 0 FROM VALUES (1), 
(2), (3), (4), (5) tab(col);
+       true
+  """,
+  group = "misc_funcs",
+  since = "4.1.0")
+case class KllSketchToStringBigint(child: Expression) extends 
KllSketchToStringBase {
+  /** Display name of this expression; also forwarded to the invalid-input error below. */
+  override def prettyName: String = "kll_sketch_to_string_bigint"
+
+  override protected def withNewChildInternal(newChild: Expression): KllSketchToStringBigint =
+    copy(child = newChild)
+
+  /**
+   * Rebuilds a `KllLongsSketch` from the serialized bytes in `input` and returns the
+   * sketch's `toString` text as a `UTF8String`.
+   *
+   * Any `Exception` raised while casting, heapifying or rendering is rethrown as
+   * `QueryExecutionErrors.kllSketchInvalidInputError`, carrying `prettyName` and the
+   * original exception message.
+   */
+  override def nullSafeEval(input: Any): Any = {
+    try {
+      val bytes = input.asInstanceOf[Array[Byte]]
+      UTF8String.fromString(KllLongsSketch.heapify(Memory.wrap(bytes)).toString())
+    } catch {
+      case cause: Exception =>
+        throw QueryExecutionErrors.kllSketchInvalidInputError(prettyName, cause.getMessage)
+    }
+  }
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = """
+    _FUNC_(expr) - Returns human readable summary information about this 
sketch.
+  """,
+  examples = """
+    Examples:
+      > SELECT LENGTH(_FUNC_(kll_sketch_agg_float(col))) > 0 FROM VALUES 
(CAST(1.0 AS FLOAT)), (CAST(2.0 AS FLOAT)), (CAST(3.0 AS FLOAT)), (CAST(4.0 AS 
FLOAT)), (CAST(5.0 AS FLOAT)) tab(col);
+       true
+  """,
+  group = "misc_funcs",
+  since = "4.1.0")
+case class KllSketchToStringFloat(child: Expression) extends 
KllSketchToStringBase {
+  /** Display name of this expression; also forwarded to the invalid-input error below. */
+  override def prettyName: String = "kll_sketch_to_string_float"
+
+  override protected def withNewChildInternal(newChild: Expression): KllSketchToStringFloat =
+    copy(child = newChild)
+
+  /**
+   * Rebuilds a `KllFloatsSketch` from the serialized bytes in `input` and returns the
+   * sketch's `toString` text as a `UTF8String`.
+   *
+   * Any `Exception` raised while casting, heapifying or rendering is rethrown as
+   * `QueryExecutionErrors.kllSketchInvalidInputError`, carrying `prettyName` and the
+   * original exception message.
+   */
+  override def nullSafeEval(input: Any): Any = {
+    try {
+      val bytes = input.asInstanceOf[Array[Byte]]
+      UTF8String.fromString(KllFloatsSketch.heapify(Memory.wrap(bytes)).toString())
+    } catch {
+      case cause: Exception =>
+        throw QueryExecutionErrors.kllSketchInvalidInputError(prettyName, cause.getMessage)
+    }
+  }
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = """
+    _FUNC_(expr) - Returns human readable summary information about this 
sketch.
+  """,
+  examples = """
+    Examples:
+      > SELECT LENGTH(_FUNC_(kll_sketch_agg_double(col))) > 0 FROM VALUES 
(CAST(1.0 AS DOUBLE)), (CAST(2.0 AS DOUBLE)), (CAST(3.0 AS DOUBLE)), (CAST(4.0 
AS DOUBLE)), (CAST(5.0 AS DOUBLE)) tab(col);
+       true
+  """,
+  group = "misc_funcs",
+  since = "4.1.0")
+case class KllSketchToStringDouble(child: Expression) extends 
KllSketchToStringBase {
+  /** Display name of this expression; also forwarded to the invalid-input error below. */
+  override def prettyName: String = "kll_sketch_to_string_double"
+
+  override protected def withNewChildInternal(newChild: Expression): KllSketchToStringDouble =
+    copy(child = newChild)
+
+  /**
+   * Rebuilds a `KllDoublesSketch` from the serialized bytes in `input` and returns the
+   * sketch's `toString` text as a `UTF8String`.
+   *
+   * Any `Exception` raised while casting, heapifying or rendering is rethrown as
+   * `QueryExecutionErrors.kllSketchInvalidInputError`, carrying `prettyName` and the
+   * original exception message.
+   */
+  override def nullSafeEval(input: Any): Any = {
+    try {
+      val bytes = input.asInstanceOf[Array[Byte]]
+      UTF8String.fromString(KllDoublesSketch.heapify(Memory.wrap(bytes)).toString())
+    } catch {
+      case cause: Exception =>
+        throw QueryExecutionErrors.kllSketchInvalidInputError(prettyName, cause.getMessage)
+    }
+  }
+}
+
+/**
+ * Shared parent of the `KllSketchToString*` expressions declared above.
+ *
+ * It is a unary expression that mixes in `CodegenFallback` and `ImplicitCastInputTypes`,
+ * declares a single `BinaryType` input, produces `StringType`, and reports itself as
+ * null-intolerant. Concrete subclasses supply `prettyName` and `nullSafeEval`.
+ */
+abstract class KllSketchToStringBase
+    extends UnaryExpression
+        with CodegenFallback
+        with ImplicitCastInputTypes {
+  override def nullIntolerant: Boolean = true
+  override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType)
+  override def dataType: DataType = StringType
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = """
+    _FUNC_(expr) - Returns the number of items collected in the sketch.
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_(kll_sketch_agg_bigint(col)) FROM VALUES (1), (2), (3), 
(4), (5) tab(col);
+       5
+  """,
+  group = "misc_funcs",
+  since = "4.1.0")
+case class KllSketchGetNBigint(child: Expression) extends KllSketchGetNBase {
+  /** Display name of this expression; also forwarded to the invalid-input error below. */
+  override def prettyName: String = "kll_sketch_get_n_bigint"
+
+  override protected def withNewChildInternal(newChild: Expression): KllSketchGetNBigint =
+    copy(child = newChild)
+
+  /**
+   * Rebuilds a `KllLongsSketch` from the serialized bytes in `input` and returns the
+   * value of its `getN()`.
+   *
+   * Any `Exception` raised along the way is rethrown as
+   * `QueryExecutionErrors.kllSketchInvalidInputError`, carrying `prettyName` and the
+   * original exception message.
+   */
+  override def nullSafeEval(input: Any): Any = {
+    try {
+      val bytes = input.asInstanceOf[Array[Byte]]
+      KllLongsSketch.heapify(Memory.wrap(bytes)).getN()
+    } catch {
+      case cause: Exception =>
+        throw QueryExecutionErrors.kllSketchInvalidInputError(prettyName, cause.getMessage)
+    }
+  }
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = """
+    _FUNC_(expr) - Returns the number of items collected in the sketch.
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_(kll_sketch_agg_float(col)) FROM VALUES (CAST(1.0 AS 
FLOAT)), (CAST(2.0 AS FLOAT)), (CAST(3.0 AS FLOAT)), (CAST(4.0 AS FLOAT)), 
(CAST(5.0 AS FLOAT)) tab(col);
+       5
+  """,
+  group = "misc_funcs",
+  since = "4.1.0")
+case class KllSketchGetNFloat(child: Expression) extends KllSketchGetNBase {
+  /** Display name of this expression; also forwarded to the invalid-input error below. */
+  override def prettyName: String = "kll_sketch_get_n_float"
+
+  override protected def withNewChildInternal(newChild: Expression): KllSketchGetNFloat =
+    copy(child = newChild)
+
+  /**
+   * Rebuilds a `KllFloatsSketch` from the serialized bytes in `input` and returns the
+   * value of its `getN()`.
+   *
+   * Any `Exception` raised along the way is rethrown as
+   * `QueryExecutionErrors.kllSketchInvalidInputError`, carrying `prettyName` and the
+   * original exception message.
+   */
+  override def nullSafeEval(input: Any): Any = {
+    try {
+      val bytes = input.asInstanceOf[Array[Byte]]
+      KllFloatsSketch.heapify(Memory.wrap(bytes)).getN()
+    } catch {
+      case cause: Exception =>
+        throw QueryExecutionErrors.kllSketchInvalidInputError(prettyName, cause.getMessage)
+    }
+  }
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = """
+    _FUNC_(expr) - Returns the number of items collected in the sketch.
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_(kll_sketch_agg_double(col)) FROM VALUES (CAST(1.0 AS 
DOUBLE)), (CAST(2.0 AS DOUBLE)), (CAST(3.0 AS DOUBLE)), (CAST(4.0 AS DOUBLE)), 
(CAST(5.0 AS DOUBLE)) tab(col);
+       5
+  """,
+  group = "misc_funcs",
+  since = "4.1.0")
+case class KllSketchGetNDouble(child: Expression) extends KllSketchGetNBase {
+  /** Display name of this expression; also forwarded to the invalid-input error below. */
+  override def prettyName: String = "kll_sketch_get_n_double"
+
+  override protected def withNewChildInternal(newChild: Expression): KllSketchGetNDouble =
+    copy(child = newChild)
+
+  /**
+   * Rebuilds a `KllDoublesSketch` from the serialized bytes in `input` and returns the
+   * value of its `getN()`.
+   *
+   * Any `Exception` raised along the way is rethrown as
+   * `QueryExecutionErrors.kllSketchInvalidInputError`, carrying `prettyName` and the
+   * original exception message.
+   */
+  override def nullSafeEval(input: Any): Any = {
+    try {
+      val bytes = input.asInstanceOf[Array[Byte]]
+      KllDoublesSketch.heapify(Memory.wrap(bytes)).getN()
+    } catch {
+      case cause: Exception =>
+        throw QueryExecutionErrors.kllSketchInvalidInputError(prettyName, cause.getMessage)
+    }
+  }
+}
+
+/**
+ * Shared parent of the `KllSketchGetN*` expressions declared above.
+ *
+ * It is a unary expression that mixes in `CodegenFallback` and `ImplicitCastInputTypes`,
+ * declares a single `BinaryType` input, produces `LongType`, and reports itself as
+ * null-intolerant. Concrete subclasses supply `prettyName` and `nullSafeEval`.
+ */
+abstract class KllSketchGetNBase
+    extends UnaryExpression
+        with CodegenFallback
+        with ImplicitCastInputTypes {
+  override def nullIntolerant: Boolean = true
+  override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType)
+  override def dataType: DataType = LongType
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = """
+    _FUNC_(left, right) - Merges two sketch buffers together into one.
+  """,
+  examples = """
+    Examples:
+      > SELECT 
LENGTH(kll_sketch_to_string_bigint(_FUNC_(kll_sketch_agg_bigint(col), 
kll_sketch_agg_bigint(col)))) > 0 FROM VALUES (1), (2), (3), (4), (5) tab(col);
+       true
+  """,
+  group = "misc_funcs",
+  since = "4.1.0")
+case class KllSketchMergeBigint(left: Expression, right: Expression) extends 
KllSketchMergeBase {
+  /** Display name of this expression; also forwarded to the merge error below. */
+  override def prettyName: String = "kll_sketch_merge_bigint"
+
+  override def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression =
+    copy(left = newLeft, right = newRight)
+
+  /**
+   * Merges the sketch serialized in `right` into the sketch serialized in `left` and
+   * returns the merged sketch as a byte array.
+   *
+   * The left buffer is rebuilt with `KllLongsSketch.heapify`, the right buffer is opened
+   * with `KllLongsSketch.wrap`, and the result is the left sketch's `toByteArray` after
+   * the merge.
+   *
+   * Any `Exception` raised along the way is rethrown as
+   * `QueryExecutionErrors.kllSketchIncompatibleMergeError`, carrying `prettyName` and the
+   * original exception message.
+   */
+  override def nullSafeEval(left: Any, right: Any): Any = {
+    try {
+      // Both casts happen before either sketch is built, as in the original ordering.
+      val lhsBytes = left.asInstanceOf[Array[Byte]]
+      val rhsBytes = right.asInstanceOf[Array[Byte]]
+      val target = KllLongsSketch.heapify(Memory.wrap(lhsBytes))
+      val addend = KllLongsSketch.wrap(Memory.wrap(rhsBytes))
+      target.merge(addend)
+      target.toByteArray
+    } catch {
+      case cause: Exception =>
+        throw QueryExecutionErrors.kllSketchIncompatibleMergeError(prettyName, cause.getMessage)
+    }
+  }
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = """
+    _FUNC_(left, right) - Merges two sketch buffers together into one.
+  """,
+  examples = """
+    Examples:
+      > SELECT 
LENGTH(kll_sketch_to_string_float(_FUNC_(kll_sketch_agg_float(col), 
kll_sketch_agg_float(col)))) > 0 FROM VALUES (CAST(1.0 AS FLOAT)), (CAST(2.0 AS 
FLOAT)), (CAST(3.0 AS FLOAT)), (CAST(4.0 AS FLOAT)), (CAST(5.0 AS FLOAT)) 
tab(col);
+       true
+  """,
+  group = "misc_funcs",
+  since = "4.1.0")
+case class KllSketchMergeFloat(left: Expression, right: Expression) extends 
KllSketchMergeBase {
+  /** Display name of this expression; also forwarded to the merge error below. */
+  override def prettyName: String = "kll_sketch_merge_float"
+
+  override def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression =
+    copy(left = newLeft, right = newRight)
+
+  /**
+   * Merges the sketch serialized in `right` into the sketch serialized in `left` and
+   * returns the merged sketch as a byte array.
+   *
+   * The left buffer is rebuilt with `KllFloatsSketch.heapify`, the right buffer is opened
+   * with `KllFloatsSketch.wrap`, and the result is the left sketch's `toByteArray` after
+   * the merge.
+   *
+   * Any `Exception` raised along the way is rethrown as
+   * `QueryExecutionErrors.kllSketchIncompatibleMergeError`, carrying `prettyName` and the
+   * original exception message.
+   */
+  override def nullSafeEval(left: Any, right: Any): Any = {
+    try {
+      // Both casts happen before either sketch is built, as in the original ordering.
+      val lhsBytes = left.asInstanceOf[Array[Byte]]
+      val rhsBytes = right.asInstanceOf[Array[Byte]]
+      val target = KllFloatsSketch.heapify(Memory.wrap(lhsBytes))
+      val addend = KllFloatsSketch.wrap(Memory.wrap(rhsBytes))
+      target.merge(addend)
+      target.toByteArray
+    } catch {
+      case cause: Exception =>
+        throw QueryExecutionErrors.kllSketchIncompatibleMergeError(prettyName, cause.getMessage)
+    }
+  }
+}
+
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = """
+    _FUNC_(left, right) - Merges two sketch buffers together into one.
+  """,
+  examples = """
+    Examples:
+      > SELECT 
LENGTH(kll_sketch_to_string_double(_FUNC_(kll_sketch_agg_double(col), 
kll_sketch_agg_double(col)))) > 0 FROM VALUES (CAST(1.0 AS DOUBLE)), (CAST(2.0 
AS DOUBLE)), (CAST(3.0 AS DOUBLE)), (CAST(4.0 AS DOUBLE)), (CAST(5.0 AS 
DOUBLE)) tab(col);
+       true
+  """,
+  group = "misc_funcs",
+  since = "4.1.0")
+case class KllSketchMergeDouble(left: Expression, right: Expression) extends 
KllSketchMergeBase {

Review Comment:
   Great feature! @dtenedor
   
   What should I use to aggregate KLL sketches together in a group by? Being 
aggregated in a group by is the biggest advantage of having sketches: they act 
as partial aggregates and contribute to a complete aggregate at query time. 
However, this is not currently supported.
   
   ```
   SELECT
     kll_sketch_merge_double(kll)
   FROM quantile.table
   GROUP BY name
   ```
   
   yields an error
   
   ```
   
[[WRONG_NUM_ARGS.WITHOUT_SUGGESTION](https://learn.microsoft.com/azure/databricks/error-messages/wrong-num-args-error-class#without_suggestion)]
 The `kll_sketch_merge_double` requires 2 parameters but the actual number is 
1. Please, refer to 
'https://spark.apache.org/docs/latest/sql-ref-functions.html' for a fix.
   ```
   
   (Azure Databricks runtime 18.0, spark 2.13)
   
   I believe it's an easy lift to make this usable as a group by aggregator. 
Please help!



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to