This is an automated email from the ASF dual-hosted git repository.

jmalkin pushed a commit to branch to_binary
in repository https://gitbox.apache.org/repos/asf/datasketches-spark.git

commit be83fd5a692273eeedfa62822242787a878a575f
Author: Jon <jmalkin.nore...@apache.org>
AuthorDate: Mon Mar 24 22:53:38 2025 -0700

    add function to cast sketches to BinaryType to handle spark weirdness
---
 python/src/datasketches_spark/common.py            | 16 +++++
 python/tests/kll_test.py                           | 13 +++-
 .../sql/datasketches/common/CastToBinary.scala     | 71 ++++++++++++++++++++++
 .../common/DatasketchesFunctionRegistry.scala      |  7 +++
 .../spark/sql/datasketches/common/functions.scala  | 31 ++++++++++
 .../spark/sql/datasketches/kll/KllTest.scala       | 23 ++++++-
 6 files changed, 158 insertions(+), 3 deletions(-)

diff --git a/python/src/datasketches_spark/common.py b/python/src/datasketches_spark/common.py
index d58924c..31bc935 100644
--- a/python/src/datasketches_spark/common.py
+++ b/python/src/datasketches_spark/common.py
@@ -17,11 +17,13 @@
 
 from pyspark import SparkContext
 from pyspark.sql.column import Column, _to_java_column, _to_seq, _create_column_from_literal
+from pyspark.sql.utils import try_remote_functions
 from py4j.java_gateway import JavaClass
 from typing import Any, TypeVar, Union, Callable
 from functools import lru_cache
 from ._version import __version__
 
+
 import os
 from importlib.resources import files, as_file
 
@@ -118,3 +120,17 @@ def _array_as_java_column(data: Union[list, tuple]) -> Column:
     col = _to_seq(sc, [_create_column_from_literal(x) for x in data])
     return _invoke_function(_get_spark_functions_class(), "array", col)._jc
     #return _invoke_function(_get_spark_functions_class(), "array", _to_seq(sc, [_create_column_from_literal(x) for x in data]))._jc
+
+
+_common_functions_class: JavaClass = None
+
+def _get_common_functions_class() -> JavaClass:
+    global _common_functions_class
+    if _common_functions_class is None:
+        _common_functions_class = _get_jvm_class("org.apache.spark.sql.datasketches.common.functions")
+    return _common_functions_class
+
+
+@try_remote_functions
+def cast_to_binary(col: "ColumnOrName") -> Column:
+    return _invoke_function_over_columns(_get_common_functions_class(), "cast_to_binary", col)
diff --git a/python/tests/kll_test.py b/python/tests/kll_test.py
index a37c2f3..eb115fa 100644
--- a/python/tests/kll_test.py
+++ b/python/tests/kll_test.py
@@ -15,9 +15,10 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from pyspark.sql.types import StructType, StructField, DoubleType, IntegerType
+from pyspark.sql.types import StructType, StructField, BinaryType, DoubleType, IntegerType
 
-from datasketches import kll_doubles_sketch
+#from datasketches import kll_doubles_sketch
+from datasketches_spark.common import cast_to_binary
 from datasketches_spark.kll import *
 
 def test_kll_build(spark):
@@ -45,6 +46,14 @@ def test_kll_build(spark):
   assert(sk.get_pmf([25000, 30000, 75000]) == result["pmf"])
   assert(sk.get_cdf([20000, 50000, 95000], False) == result["cdf"])
 
+  df_types = df_agg.select(
+    "sketch",
+    cast_to_binary("sketch").alias("asBinary")
+  )
+  assert(df_types.schema["sketch"].dataType == KllDoublesSketchUDT())
+  assert(df_types.schema["asBinary"].dataType == BinaryType())
+
+
 def test_kll_merge(spark):
   n = 75 # stay in exact mode
   k = 200
diff --git a/src/main/scala/org/apache/spark/sql/datasketches/common/CastToBinary.scala b/src/main/scala/org/apache/spark/sql/datasketches/common/CastToBinary.scala
new file mode 100644
index 0000000..8d7d514
--- /dev/null
+++ b/src/main/scala/org/apache/spark/sql/datasketches/common/CastToBinary.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.datasketches.common
+
+import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, BinaryType, DataType}
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, ExpectsInputTypes}
+import org.apache.spark.sql.catalyst.expressions.UnaryExpression
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodeBlock, CodegenContext, ExprCode}
+
+@ExpressionDescription(
+  usage = """
+    _FUNC_(expr) - Returns the input as a BinaryType (Array[Byte]). """
+/*    ,
+  examples = """
+    Examples:
+      > SELECT _FUNC_(kll_sketch_agg(col)) FROM VALUES (1.0), (2.0), (3.0) tab(col);
+       1.0
+  """*/
+  //group = "misc_funcs",
+)
+case class CastToBinary(sketchExpr: Expression)
+ extends UnaryExpression
+ with ExpectsInputTypes {
+
+  override def child: Expression = sketchExpr
+
+  override protected def withNewChildInternal(newChild: Expression): CastToBinary = {
+    copy(sketchExpr = newChild)
+  }
+
+  override def prettyName: String = "sketch_to_binary_converter"
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
+
+  override def dataType: DataType = BinaryType
+
+  override def nullSafeEval(input: Any): Any = {
+    input.asInstanceOf[Array[Byte]]
+  }
+
+  override protected def nullSafeCodeGen(ctx: CodegenContext, ev: ExprCode, f: String => String): ExprCode = {
+    val sketchEval = child.genCode(ctx)
+
+    val code =
+      s"""
+         |${sketchEval.code}
+         |final byte[] ${ev.value} = ${sketchEval.value};
+         |final boolean ${ev.isNull} = ${sketchEval.isNull};
+       """.stripMargin
+    ev.copy(code = CodeBlock(Seq(code), Seq.empty))
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    nullSafeCodeGen(ctx, ev, c => s"($c)")
+  }
+}
diff --git a/src/main/scala/org/apache/spark/sql/datasketches/common/DatasketchesFunctionRegistry.scala b/src/main/scala/org/apache/spark/sql/datasketches/common/DatasketchesFunctionRegistry.scala
index bdc8235..29207b0 100644
--- a/src/main/scala/org/apache/spark/sql/datasketches/common/DatasketchesFunctionRegistry.scala
+++ b/src/main/scala/org/apache/spark/sql/datasketches/common/DatasketchesFunctionRegistry.scala
@@ -58,3 +58,10 @@ trait DatasketchesFunctionRegistry {
     (name, (expressionInfo, builder))
   }
 }
+
+// object for common functions
+object CommonFunctionRegistry extends DatasketchesFunctionRegistry {
+  override val expressions: Map[String, (ExpressionInfo, FunctionBuilder)] = Map(
+    expression[CastToBinary]("cast_to_binary"),
+  )
+}
diff --git a/src/main/scala/org/apache/spark/sql/datasketches/common/functions.scala b/src/main/scala/org/apache/spark/sql/datasketches/common/functions.scala
new file mode 100644
index 0000000..ff8fcf0
--- /dev/null
+++ b/src/main/scala/org/apache/spark/sql/datasketches/common/functions.scala
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.datasketches.common
+
+import org.apache.spark.sql.Column
+
+object functions extends DatasketchesScalaFunctionBase {
+
+  def cast_to_binary(expr: Column): Column = withExpr {
+    new CastToBinary(expr.expr)
+  }
+
+  def sketch_to_binary(columnName: String): Column = {
+    cast_to_binary(Column(columnName))
+  }
+}
diff --git a/src/test/scala/org/apache/spark/sql/datasketches/kll/KllTest.scala b/src/test/scala/org/apache/spark/sql/datasketches/kll/KllTest.scala
index 0a17a9d..acb8f0d 100644
--- a/src/test/scala/org/apache/spark/sql/datasketches/kll/KllTest.scala
+++ b/src/test/scala/org/apache/spark/sql/datasketches/kll/KllTest.scala
@@ -25,7 +25,8 @@ import org.apache.spark.sql.types.{StructType, StructField, IntegerType, BinaryT
 import org.apache.datasketches.kll.KllDoublesSketch
 import org.apache.spark.sql.datasketches.kll.functions._
 import org.apache.spark.sql.datasketches.kll.types.KllDoublesSketchType
-import org.apache.spark.sql.datasketches.common.SparkSessionManager
+import org.apache.spark.sql.datasketches.common.{SparkSessionManager, CommonFunctionRegistry}
+import org.apache.spark.sql.datasketches.common.functions.cast_to_binary
 
 class KllTest extends SparkSessionManager {
   import spark.implicits._
@@ -116,11 +117,16 @@ class KllTest extends SparkSessionManager {
 
     val cdf_excl = Array[Double](0.2, 0.49, 1.0, 1.0)
     compareArrays(cdf_excl, pmfCdfResult.getAs[Seq[Double]]("cdf_exclusive").toArray)
+
+    val resultSchema = sketchDf.select($"sketch", cast_to_binary($"sketch").as("asBinary")).schema
+    assert(resultSchema.apply("sketch").dataType.equals(KllDoublesSketchType))
+    assert(resultSchema.apply("asBinary").dataType.equals(BinaryType))
   }
 
   test("Kll Doubles Sketch via SQL") {
     // register KLL functions
     KllFunctionRegistry.registerFunctions(spark)
+    CommonFunctionRegistry.registerFunctions(spark)
 
     val n = 100
     val data = (for (i <- 1 to n) yield i.toDouble).toDF("value")
@@ -167,6 +173,21 @@ class KllTest extends SparkSessionManager {
 
     val cdf_excl = Array[Double](0.2, 0.49, 1.0, 1.0)
     compareArrays(cdf_excl, pmfCdfResult.getAs[Seq[Double]]("cdf_exclusive").toArray)
+
+    val schemaCheckResult = spark.sql(
+      s"""
+      |SELECT
+      |  kll_sketch_double_agg_build(value, 200) AS sketch,
+      |  cast_to_binary(kll_sketch_double_agg_build(value, 200)) AS asBinary
+      |FROM
+      |  data_table
+      """.stripMargin
+    )
+
+    val resultSchema = schemaCheckResult.schema
+    assert(resultSchema.apply("sketch").dataType.equals(KllDoublesSketchType))
+    assert(resultSchema.apply("asBinary").dataType.equals(BinaryType))
+
   }
 
   test("KLL Doubles Merge via Scala") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datasketches.apache.org
For additional commands, e-mail: commits-h...@datasketches.apache.org

Reply via email to