This is an automated email from the ASF dual-hosted git repository.

jmalkin pushed a commit to branch python
in repository https://gitbox.apache.org/repos/asf/datasketches-spark.git


The following commit(s) were added to refs/heads/python by this push:
     new 71bc84e  final(?) changes for baseline kll in pyspark, including useful tests
71bc84e is described below

commit 71bc84ef579594e425d980af2c411dc37970320c
Author: Jon Malkin <[email protected]>
AuthorDate: Thu Feb 13 19:55:36 2025 -0800

    final(?) changes for baseline kll in pyspark, including useful tests
---
 python/pyproject.toml                              |   3 +-
 python/src/datasketches_spark/kll.py               |  29 +++++-
 python/{pyproject.toml => tests/conftest.py}       |  46 ++++------
 python/tests/kll_test.py                           | 101 +++++++++++++--------
 .../kll/types/KllDoublesSketchType.scala           |   4 +
 5 files changed, 114 insertions(+), 69 deletions(-)

diff --git a/python/pyproject.toml b/python/pyproject.toml
index 5f41e9e..beac16e 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -30,7 +30,8 @@ license = { text = "Apache License 2.0" }
 readme = "README.md"
 requires-python = ">=3.8"
 dependencies = [
-  "pyspark"
+  "pyspark",
+  "datasketches"
 ]
 
 [tool.setuptools]
diff --git a/python/src/datasketches_spark/kll.py b/python/src/datasketches_spark/kll.py
index f89ef16..d2d11af 100644
--- a/python/src/datasketches_spark/kll.py
+++ b/python/src/datasketches_spark/kll.py
@@ -17,10 +17,13 @@
 
 from typing import List, Optional, Tuple, Union
 from py4j.java_gateway import JavaClass
-from pyspark.sql.column import Column, _to_java_column, _to_seq, _create_column_from_literal
+from pyspark.sql.column import Column, _to_java_column # possibly fragile
 from pyspark.sql.functions import lit
 from pyspark.sql.utils import try_remote_functions
 
+from pyspark.sql.types import UserDefinedType, BinaryType
+from datasketches import kll_doubles_sketch
+
 from .common import (
                      ColumnOrName,
                      _invoke_function,
@@ -37,6 +40,30 @@ def _get_kll_functions_class() -> JavaClass:
         _kll_functions_class = _get_jvm_class("org.apache.spark.sql.datasketches.kll.functions")
     return _kll_functions_class
 
+class KllDoublesSketchUDT(UserDefinedType):
+    """UDT to translate kll_doubles_sketch to/from spark"""
+
+    @classmethod
+    def sqlType(cls):
+        return BinaryType()
+
+    def serialize(self, sketch: kll_doubles_sketch) -> bytes:
+        if sketch is None:
+            return None
+        return sketch.serialize()
+
+    def deserialize(self, data: bytes) -> kll_doubles_sketch:
+        if data is None:
+            return None
+        return kll_doubles_sketch.deserialize(bytes(data))
+
+    @classmethod
+    def module(cls):
+        return "datasketches"
+
+    @classmethod
+    def scalaUDT(cls):
+        return "org.apache.spark.sql.datasketches.kll.KllDoublesSketchType"
 
 @try_remote_functions
 def kll_sketch_double_agg_build(col: "ColumnOrName", k: Optional[Union[int, Column]] = None) -> Column:
diff --git a/python/pyproject.toml b/python/tests/conftest.py
similarity index 51%
copy from python/pyproject.toml
copy to python/tests/conftest.py
index 5f41e9e..5a3fe2d 100644
--- a/python/pyproject.toml
+++ b/python/tests/conftest.py
@@ -15,33 +15,21 @@
 # specific language governing permissions and limitations
 # under the License.
 
-[build-system]
-requires = ["setuptools", "wheel"]
-build-backend = "setuptools.build_meta"
+import pytest
+from pyspark.sql import SparkSession
+from datasketches_spark import get_dependency_classpath
 
-[project]
-name = "datasketches_spark"
-dynamic = ["version"]
-description = "The Apache DataSketches Library for Python"
-authors = [
-  { name = "Apache Software Foundation", email = "[email protected]" }
-]
-license = { text = "Apache License 2.0" }
-readme = "README.md"
-requires-python = ">=3.8"
-dependencies = [
-  "pyspark"
-]
-
-[tool.setuptools]
-package-dir = { "" = "src" }
-
-[tool.setuptools.dynamic]
-version = { file = "src/datasketches_spark/version.txt" }
-
-[tool.setuptools.package-data]
-datasketches_spark = ["version.txt", "deps/*"]
-
-[tool.cibuildwheel]
-build-verbosity = 0  # options: 1, 2, or 3
-skip = ["cp36-*", "cp37-*", "cp38-*", "pp*", "*-win32"]
\ No newline at end of file
+@pytest.fixture(scope="session")
+def spark():
+    spark = (
+        SparkSession.builder
+        .appName("test")
+        .master("local[*]")
+        .config("spark.driver.userClassPathFirst", "true")
+        .config("spark.executor.userClassPathFirst", "true")
+        .config("spark.driver.extraClassPath", get_dependency_classpath())
+        .config("spark.executor.extraClassPath", get_dependency_classpath())
+        .getOrCreate()
+    )
+    yield spark
+    spark.stop()
diff --git a/python/tests/kll_test.py b/python/tests/kll_test.py
index c142049..0ca850d 100644
--- a/python/tests/kll_test.py
+++ b/python/tests/kll_test.py
@@ -1,43 +1,68 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
 
-import unittest
 from pyspark.sql.types import StructType, StructField, DoubleType
-from pyspark.sql.session import SparkSession
 
-from datasketches_spark import get_dependency_classpath
+from datasketches import kll_doubles_sketch
 from datasketches_spark.kll import *
 
-class PySparkBase(unittest.TestCase):
-  @classmethod
-  def setUpClass(cls):
-    cls.spark = (
-          SparkSession.builder
-            .appName("test")
-            .master("local[1]")
-            .config("spark.driver.userClassPathFirst", "true")
-            .config("spark.executor.userClassPathFirst", "true")
-            .config("spark.driver.extraClassPath", get_dependency_classpath())
-            .config("spark.executor.extraClassPath", get_dependency_classpath())
-            .getOrCreate()
-        )
-
-  @classmethod
-  def tearDownClass(cls):
-    cls.spark.stop()
-
-class TestKll(PySparkBase):
-  def test_kll(self):
-    spark = self.spark
-
-    # Create a DataFrame
-    n = 100000
-    data = [(float(i),) for i in range(1, n + 1)]
-    schema = StructType([StructField("value", DoubleType(), True)])
-    df = spark.createDataFrame(data, schema)
-    df_agg = df.agg(kll_sketch_double_agg_build("value", 160).alias("sketch"))
-    df_agg.show()
-
-    df_agg.select(
-      kll_sketch_double_get_min("sketch").alias("min"),
-      kll_sketch_double_get_max("sketch").alias("max"),
-      kll_sketch_double_get_pmf("sketch", [25000, 30000, 75000]).alias("pmf")
-    ).show()
+def test_kll_build(spark):
+  n = 100000
+  k = 160
+  data = [(float(i),) for i in range(1, n + 1)]
+  schema = StructType([StructField("value", DoubleType(), True)])
+  df = spark.createDataFrame(data, schema)
+
+  df_agg = df.agg(kll_sketch_double_agg_build("value", k).alias("sketch"))
+
+  result = df_agg.select(
+    "sketch",
+    kll_sketch_double_get_min("sketch").alias("min"),
+    kll_sketch_double_get_max("sketch").alias("max"),
+    kll_sketch_double_get_pmf("sketch", [25000, 30000, 75000]).alias("pmf"),
+    kll_sketch_double_get_cdf("sketch", [20000, 50000, 95000], False).alias("cdf")
+  ).first()
+  sk = result["sketch"]
+
+  assert(sk.n == n)
+  assert(sk.k == k)
+  assert(result["min"] == sk.get_min_value())
+  assert(result["max"] == sk.get_max_value())
+  assert(sk.get_pmf([25000, 30000, 75000]) == result["pmf"])
+  assert(sk.get_cdf([20000, 50000, 95000], False) == result["cdf"])
+
+def test_kll_merge(spark):
+  n = 75 # stay in exact mode
+  k = 200
+  data1 = [(float(i),) for i in range(1, n + 1)]
+  data2 = [(float(i),) for i in range(n + 1, 2 * n + 1)]
+  schema = StructType([StructField("value", DoubleType(), True)])
+  df1 = spark.createDataFrame(data1, schema)
+  df2 = spark.createDataFrame(data2, schema)
+
+  df_agg1 = df1.agg(kll_sketch_double_agg_build("value", k).alias("sketch"))
+  df_agg2 = df2.agg(kll_sketch_double_agg_build("value", k).alias("sketch"))
+
+  result = df_agg1.union(df_agg2).select(
+    kll_sketch_double_agg_merge("sketch").alias("sketch")
+  ).first()
+  sk = result["sketch"]
+
+  assert(sk.n == 2 * n)
+  assert(sk.k == k)
+  assert(sk.get_min_value() == 1.0)
+  assert(sk.get_max_value() == 2 * n)
diff --git a/src/main/scala/org/apache/spark/sql/datasketches/kll/types/KllDoublesSketchType.scala b/src/main/scala/org/apache/spark/sql/datasketches/kll/types/KllDoublesSketchType.scala
index a058190..59299d2 100644
--- a/src/main/scala/org/apache/spark/sql/datasketches/kll/types/KllDoublesSketchType.scala
+++ b/src/main/scala/org/apache/spark/sql/datasketches/kll/types/KllDoublesSketchType.scala
@@ -25,6 +25,10 @@ import org.apache.spark.sql.types.{DataType, DataTypes, UDTRegistration, UserDef
 class KllDoublesSketchType extends UserDefinedType[KllDoublesSketch] with Serializable {
   override def sqlType: DataType = DataTypes.BinaryType
 
+  override def serializedPyClass: String = "bytes"
+
+  override def pyUDT: String = "datasketches_spark.KllDoublesSketchUDT"
+
   override def serialize(wrapper: KllDoublesSketch): Array[Byte] = {
     wrapper.toByteArray
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to