This is an automated email from the ASF dual-hosted git repository. jmalkin pushed a commit to branch to_binary in repository https://gitbox.apache.org/repos/asf/datasketches-spark.git
commit be83fd5a692273eeedfa62822242787a878a575f Author: Jon <jmalkin.nore...@apache.org> AuthorDate: Mon Mar 24 22:53:38 2025 -0700 add function to cast sketches to BinaryType to handle spark weirdness --- python/src/datasketches_spark/common.py | 16 +++++ python/tests/kll_test.py | 13 +++- .../sql/datasketches/common/CastToBinary.scala | 71 ++++++++++++++++++++++ .../common/DatasketchesFunctionRegistry.scala | 7 +++ .../spark/sql/datasketches/common/functions.scala | 31 ++++++++++ .../spark/sql/datasketches/kll/KllTest.scala | 23 ++++++- 6 files changed, 158 insertions(+), 3 deletions(-) diff --git a/python/src/datasketches_spark/common.py b/python/src/datasketches_spark/common.py index d58924c..31bc935 100644 --- a/python/src/datasketches_spark/common.py +++ b/python/src/datasketches_spark/common.py @@ -17,11 +17,13 @@ from pyspark import SparkContext from pyspark.sql.column import Column, _to_java_column, _to_seq, _create_column_from_literal +from pyspark.sql.utils import try_remote_functions from py4j.java_gateway import JavaClass from typing import Any, TypeVar, Union, Callable from functools import lru_cache from ._version import __version__ + import os from importlib.resources import files, as_file @@ -118,3 +120,17 @@ def _array_as_java_column(data: Union[list, tuple]) -> Column: col = _to_seq(sc, [_create_column_from_literal(x) for x in data]) return _invoke_function(_get_spark_functions_class(), "array", col)._jc #return _invoke_function(_get_spark_functions_class(), "array", _to_seq(sc, [_create_column_from_literal(x) for x in data]))._jc + + +_common_functions_class: JavaClass = None + +def _get_common_functions_class() -> JavaClass: + global _common_functions_class + if _common_functions_class is None: + _common_functions_class = _get_jvm_class("org.apache.spark.sql.datasketches.common.functions") + return _common_functions_class + + +@try_remote_functions +def cast_to_binary(col: "ColumnOrName") -> Column: + return _invoke_function_over_columns(_get_common_functions_class(), "cast_to_binary", col) diff --git a/python/tests/kll_test.py b/python/tests/kll_test.py index a37c2f3..eb115fa 100644 --- a/python/tests/kll_test.py +++ b/python/tests/kll_test.py @@ -15,9 +15,10 @@ # specific language governing permissions and limitations # under the License. -from pyspark.sql.types import StructType, StructField, DoubleType, IntegerType +from pyspark.sql.types import StructType, StructField, BinaryType, DoubleType, IntegerType -from datasketches import kll_doubles_sketch +#from datasketches import kll_doubles_sketch +from datasketches_spark.common import cast_to_binary from datasketches_spark.kll import * def test_kll_build(spark): @@ -45,6 +46,14 @@ def test_kll_build(spark): assert(sk.get_pmf([25000, 30000, 75000]) == result["pmf"]) assert(sk.get_cdf([20000, 50000, 95000], False) == result["cdf"]) + df_types = df_agg.select( + "sketch", + cast_to_binary("sketch").alias("asBinary") + ) + assert(df_types.schema["sketch"].dataType == KllDoublesSketchUDT()) + assert(df_types.schema["asBinary"].dataType == BinaryType()) + + def test_kll_merge(spark): n = 75 # stay in exact mode k = 200 diff --git a/src/main/scala/org/apache/spark/sql/datasketches/common/CastToBinary.scala b/src/main/scala/org/apache/spark/sql/datasketches/common/CastToBinary.scala new file mode 100644 index 0000000..8d7d514 --- /dev/null +++ b/src/main/scala/org/apache/spark/sql/datasketches/common/CastToBinary.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.datasketches.common + +import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, BinaryType, DataType} +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, ExpectsInputTypes} +import org.apache.spark.sql.catalyst.expressions.UnaryExpression +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeBlock, CodegenContext, ExprCode} + +@ExpressionDescription( + usage = """ + _FUNC_(expr) - Returns the input as a BinaryType (Array[Byte]). """ +/* , + examples = """ + Examples: + > SELECT _FUNC_(kll_sketch_agg(col)) FROM VALUES (1.0), (2.0), (3.0) tab(col); + 1.0 + """*/ + //group = "misc_funcs", +) +case class CastToBinary(sketchExpr: Expression) + extends UnaryExpression + with ExpectsInputTypes { + + override def child: Expression = sketchExpr + + override protected def withNewChildInternal(newChild: Expression): CastToBinary = { + copy(sketchExpr = newChild) + } + + override def prettyName: String = "sketch_to_binary_converter" + + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + + override def dataType: DataType = BinaryType + + override def nullSafeEval(input: Any): Any = { + input.asInstanceOf[Array[Byte]] + } + + override protected def nullSafeCodeGen(ctx: CodegenContext, ev: ExprCode, f: String => String): ExprCode = { + val sketchEval = child.genCode(ctx) + + val code = + s""" + |${sketchEval.code} + |final byte[] ${ev.value} = ${sketchEval.value}; + |final boolean ${ev.isNull} = ${sketchEval.isNull}; + """.stripMargin + ev.copy(code = CodeBlock(Seq(code), Seq.empty)) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, c => s"($c)") + } +} diff --git a/src/main/scala/org/apache/spark/sql/datasketches/common/DatasketchesFunctionRegistry.scala b/src/main/scala/org/apache/spark/sql/datasketches/common/DatasketchesFunctionRegistry.scala index bdc8235..29207b0 100644 --- a/src/main/scala/org/apache/spark/sql/datasketches/common/DatasketchesFunctionRegistry.scala +++ b/src/main/scala/org/apache/spark/sql/datasketches/common/DatasketchesFunctionRegistry.scala @@ -58,3 +58,10 @@ trait DatasketchesFunctionRegistry { (name, (expressionInfo, builder)) } } + +// object for common functions +object CommonFunctionRegistry extends DatasketchesFunctionRegistry { + override val expressions: Map[String, (ExpressionInfo, FunctionBuilder)] = Map( + expression[CastToBinary]("cast_to_binary"), + ) +} diff --git a/src/main/scala/org/apache/spark/sql/datasketches/common/functions.scala b/src/main/scala/org/apache/spark/sql/datasketches/common/functions.scala new file mode 100644 index 0000000..ff8fcf0 --- /dev/null +++ b/src/main/scala/org/apache/spark/sql/datasketches/common/functions.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.datasketches.common + +import org.apache.spark.sql.Column + +object functions extends DatasketchesScalaFunctionBase { + + def cast_to_binary(expr: Column): Column = withExpr { + new CastToBinary(expr.expr) + } + + def sketch_to_binary(columnName: String): Column = { + cast_to_binary(Column(columnName)) + } +} diff --git a/src/test/scala/org/apache/spark/sql/datasketches/kll/KllTest.scala b/src/test/scala/org/apache/spark/sql/datasketches/kll/KllTest.scala index 0a17a9d..acb8f0d 100644 --- a/src/test/scala/org/apache/spark/sql/datasketches/kll/KllTest.scala +++ b/src/test/scala/org/apache/spark/sql/datasketches/kll/KllTest.scala @@ -25,7 +25,8 @@ import org.apache.spark.sql.types.{StructType, StructField, IntegerType, BinaryT import org.apache.datasketches.kll.KllDoublesSketch import org.apache.spark.sql.datasketches.kll.functions._ import org.apache.spark.sql.datasketches.kll.types.KllDoublesSketchType -import org.apache.spark.sql.datasketches.common.SparkSessionManager +import org.apache.spark.sql.datasketches.common.{SparkSessionManager, CommonFunctionRegistry} +import org.apache.spark.sql.datasketches.common.functions.cast_to_binary class KllTest extends SparkSessionManager { import spark.implicits._ @@ -116,11 +117,16 @@ class KllTest extends SparkSessionManager { val cdf_excl = Array[Double](0.2, 0.49, 1.0, 1.0) compareArrays(cdf_excl, pmfCdfResult.getAs[Seq[Double]]("cdf_exclusive").toArray) + + val resultSchema = sketchDf.select($"sketch", cast_to_binary($"sketch").as("asBinary")).schema + assert(resultSchema.apply("sketch").dataType.equals(KllDoublesSketchType)) + assert(resultSchema.apply("asBinary").dataType.equals(BinaryType)) } test("Kll Doubles Sketch via SQL") { // register KLL functions KllFunctionRegistry.registerFunctions(spark) + CommonFunctionRegistry.registerFunctions(spark) val n = 100 val data = (for (i <- 1 to n) yield i.toDouble).toDF("value") @@ -167,6 +173,21 @@ class KllTest extends SparkSessionManager { val cdf_excl = Array[Double](0.2, 0.49, 1.0, 1.0) compareArrays(cdf_excl, pmfCdfResult.getAs[Seq[Double]]("cdf_exclusive").toArray) + + val schemaCheckResult = spark.sql( + s""" + |SELECT + | kll_sketch_double_agg_build(value, 200) AS sketch, + | cast_to_binary(kll_sketch_double_agg_build(value, 200)) AS asBinary + |FROM + | data_table + """.stripMargin + ) + + val resultSchema = schemaCheckResult.schema + assert(resultSchema.apply("sketch").dataType.equals(KllDoublesSketchType)) + assert(resultSchema.apply("asBinary").dataType.equals(BinaryType)) + } test("KLL Doubles Merge via Scala") { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@datasketches.apache.org For additional commands, e-mail: commits-h...@datasketches.apache.org