This is an automated email from the ASF dual-hosted git repository.
jmalkin pushed a commit to branch python
in repository https://gitbox.apache.org/repos/asf/datasketches-spark.git
The following commit(s) were added to refs/heads/python by this push:
new 71bc84e final(?) changes for baseline kll in pyspark, including useful tests
71bc84e is described below
commit 71bc84ef579594e425d980af2c411dc37970320c
Author: Jon Malkin <[email protected]>
AuthorDate: Thu Feb 13 19:55:36 2025 -0800
final(?) changes for baseline kll in pyspark, including useful tests
---
python/pyproject.toml | 3 +-
python/src/datasketches_spark/kll.py | 29 +++++-
python/{pyproject.toml => tests/conftest.py} | 46 ++++------
python/tests/kll_test.py | 101 +++++++++++++--------
.../kll/types/KllDoublesSketchType.scala | 4 +
5 files changed, 114 insertions(+), 69 deletions(-)
diff --git a/python/pyproject.toml b/python/pyproject.toml
index 5f41e9e..beac16e 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -30,7 +30,8 @@ license = { text = "Apache License 2.0" }
readme = "README.md"
requires-python = ">=3.8"
dependencies = [
- "pyspark"
+ "pyspark",
+ "datasketches"
]
[tool.setuptools]
diff --git a/python/src/datasketches_spark/kll.py b/python/src/datasketches_spark/kll.py
index f89ef16..d2d11af 100644
--- a/python/src/datasketches_spark/kll.py
+++ b/python/src/datasketches_spark/kll.py
@@ -17,10 +17,13 @@
from typing import List, Optional, Tuple, Union
from py4j.java_gateway import JavaClass
-from pyspark.sql.column import Column, _to_java_column, _to_seq, _create_column_from_literal
+from pyspark.sql.column import Column, _to_java_column # possibly fragile
from pyspark.sql.functions import lit
from pyspark.sql.utils import try_remote_functions
+from pyspark.sql.types import UserDefinedType, BinaryType
+from datasketches import kll_doubles_sketch
+
from .common import (
ColumnOrName,
_invoke_function,
@@ -37,6 +40,30 @@ def _get_kll_functions_class() -> JavaClass:
_kll_functions_class = _get_jvm_class("org.apache.spark.sql.datasketches.kll.functions")
return _kll_functions_class
+class KllDoublesSketchUDT(UserDefinedType):
+ """UDT to translate kll_doubles_sketch to/from spark"""
+
+ @classmethod
+ def sqlType(cls):
+ return BinaryType()
+
+ def serialize(self, sketch: kll_doubles_sketch) -> bytes:
+ if sketch is None:
+ return None
+ return sketch.serialize()
+
+ def deserialize(self, data: bytes) -> kll_doubles_sketch:
+ if data is None:
+ return None
+ return kll_doubles_sketch.deserialize(bytes(data))
+
+ @classmethod
+ def module(cls):
+ return "datasketches"
+
+ @classmethod
+ def scalaUDT(cls):
+ return "org.apache.spark.sql.datasketches.kll.KllDoublesSketchType"
@try_remote_functions
def kll_sketch_double_agg_build(col: "ColumnOrName", k: Optional[Union[int, Column]] = None) -> Column:
diff --git a/python/pyproject.toml b/python/tests/conftest.py
similarity index 51%
copy from python/pyproject.toml
copy to python/tests/conftest.py
index 5f41e9e..5a3fe2d 100644
--- a/python/pyproject.toml
+++ b/python/tests/conftest.py
@@ -15,33 +15,21 @@
# specific language governing permissions and limitations
# under the License.
-[build-system]
-requires = ["setuptools", "wheel"]
-build-backend = "setuptools.build_meta"
+import pytest
+from pyspark.sql import SparkSession
+from datasketches_spark import get_dependency_classpath
-[project]
-name = "datasketches_spark"
-dynamic = ["version"]
-description = "The Apache DataSketches Library for Python"
-authors = [
- { name = "Apache Software Foundation", email = "[email protected]" }
-]
-license = { text = "Apache License 2.0" }
-readme = "README.md"
-requires-python = ">=3.8"
-dependencies = [
- "pyspark"
-]
-
-[tool.setuptools]
-package-dir = { "" = "src" }
-
-[tool.setuptools.dynamic]
-version = { file = "src/datasketches_spark/version.txt" }
-
-[tool.setuptools.package-data]
-datasketches_spark = ["version.txt", "deps/*"]
-
-[tool.cibuildwheel]
-build-verbosity = 0 # options: 1, 2, or 3
-skip = ["cp36-*", "cp37-*", "cp38-*", "pp*", "*-win32"]
\ No newline at end of file
[email protected](scope="session")
+def spark():
+ spark = (
+ SparkSession.builder
+ .appName("test")
+ .master("local[*]")
+ .config("spark.driver.userClassPathFirst", "true")
+ .config("spark.executor.userClassPathFirst", "true")
+ .config("spark.driver.extraClassPath", get_dependency_classpath())
+ .config("spark.executor.extraClassPath", get_dependency_classpath())
+ .getOrCreate()
+ )
+ yield spark
+ spark.stop()
diff --git a/python/tests/kll_test.py b/python/tests/kll_test.py
index c142049..0ca850d 100644
--- a/python/tests/kll_test.py
+++ b/python/tests/kll_test.py
@@ -1,43 +1,68 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
-import unittest
from pyspark.sql.types import StructType, StructField, DoubleType
-from pyspark.sql.session import SparkSession
-from datasketches_spark import get_dependency_classpath
+from datasketches import kll_doubles_sketch
from datasketches_spark.kll import *
-class PySparkBase(unittest.TestCase):
- @classmethod
- def setUpClass(cls):
- cls.spark = (
- SparkSession.builder
- .appName("test")
- .master("local[1]")
- .config("spark.driver.userClassPathFirst", "true")
- .config("spark.executor.userClassPathFirst", "true")
- .config("spark.driver.extraClassPath", get_dependency_classpath())
- .config("spark.executor.extraClassPath", get_dependency_classpath())
- .getOrCreate()
- )
-
- @classmethod
- def tearDownClass(cls):
- cls.spark.stop()
-
-class TestKll(PySparkBase):
- def test_kll(self):
- spark = self.spark
-
- # Create a DataFrame
- n = 100000
- data = [(float(i),) for i in range(1, n + 1)]
- schema = StructType([StructField("value", DoubleType(), True)])
- df = spark.createDataFrame(data, schema)
- df_agg = df.agg(kll_sketch_double_agg_build("value", 160).alias("sketch"))
- df_agg.show()
-
- df_agg.select(
- kll_sketch_double_get_min("sketch").alias("min"),
- kll_sketch_double_get_max("sketch").alias("max"),
- kll_sketch_double_get_pmf("sketch", [25000, 30000, 75000]).alias("pmf")
- ).show()
+def test_kll_build(spark):
+ n = 100000
+ k = 160
+ data = [(float(i),) for i in range(1, n + 1)]
+ schema = StructType([StructField("value", DoubleType(), True)])
+ df = spark.createDataFrame(data, schema)
+
+ df_agg = df.agg(kll_sketch_double_agg_build("value", k).alias("sketch"))
+
+ result = df_agg.select(
+ "sketch",
+ kll_sketch_double_get_min("sketch").alias("min"),
+ kll_sketch_double_get_max("sketch").alias("max"),
+ kll_sketch_double_get_pmf("sketch", [25000, 30000, 75000]).alias("pmf"),
+ kll_sketch_double_get_cdf("sketch", [20000, 50000, 95000], False).alias("cdf")
+ ).first()
+ sk = result["sketch"]
+
+ assert(sk.n == n)
+ assert(sk.k == k)
+ assert(result["min"] == sk.get_min_value())
+ assert(result["max"] == sk.get_max_value())
+ assert(sk.get_pmf([25000, 30000, 75000]) == result["pmf"])
+ assert(sk.get_cdf([20000, 50000, 95000], False) == result["cdf"])
+
+def test_kll_merge(spark):
+ n = 75 # stay in exact mode
+ k = 200
+ data1 = [(float(i),) for i in range(1, n + 1)]
+ data2 = [(float(i),) for i in range(n + 1, 2 * n + 1)]
+ schema = StructType([StructField("value", DoubleType(), True)])
+ df1 = spark.createDataFrame(data1, schema)
+ df2 = spark.createDataFrame(data2, schema)
+
+ df_agg1 = df1.agg(kll_sketch_double_agg_build("value", k).alias("sketch"))
+ df_agg2 = df2.agg(kll_sketch_double_agg_build("value", k).alias("sketch"))
+
+ result = df_agg1.union(df_agg2).select(
+ kll_sketch_double_agg_merge("sketch").alias("sketch")
+ ).first()
+ sk = result["sketch"]
+
+ assert(sk.n == 2 * n)
+ assert(sk.k == k)
+ assert(sk.get_min_value() == 1.0)
+ assert(sk.get_max_value() == 2 * n)
diff --git a/src/main/scala/org/apache/spark/sql/datasketches/kll/types/KllDoublesSketchType.scala b/src/main/scala/org/apache/spark/sql/datasketches/kll/types/KllDoublesSketchType.scala
index a058190..59299d2 100644
--- a/src/main/scala/org/apache/spark/sql/datasketches/kll/types/KllDoublesSketchType.scala
+++ b/src/main/scala/org/apache/spark/sql/datasketches/kll/types/KllDoublesSketchType.scala
@@ -25,6 +25,10 @@ import org.apache.spark.sql.types.{DataType, DataTypes, UDTRegistration, UserDef
class KllDoublesSketchType extends UserDefinedType[KllDoublesSketch] with Serializable {
override def sqlType: DataType = DataTypes.BinaryType
+ override def serializedPyClass: String = "bytes"
+
+ override def pyUDT: String = "datasketches_spark.KllDoublesSketchUDT"
+
override def serialize(wrapper: KllDoublesSketch): Array[Byte] = {
wrapper.toByteArray
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]