This is an automated email from the ASF dual-hosted git repository.

jmalkin pushed a commit to branch python
in repository https://gitbox.apache.org/repos/asf/datasketches-spark.git

commit a1dd27810823a5e8e349a72d7ea2cb678683a56c
Author: Jon Malkin <[email protected]>
AuthorDate: Thu Jan 30 11:21:08 2025 -0800

    WIP: Initial files for PySpark testing. Not fully working, but checkpointing here.
---
 python/datasketches_spark/__init__.py              | 22 +++++++
 python/datasketches_spark/common.py                | 76 ++++++++++++++++++++++
 python/datasketches_spark/kll.py                   | 71 ++++++++++++++++++++
 .../expressions/KllDoublesSketchExpressions.scala  |  6 +-
 4 files changed, 172 insertions(+), 3 deletions(-)

diff --git a/python/datasketches_spark/__init__.py 
b/python/datasketches_spark/__init__.py
new file mode 100644
index 0000000..ef9effb
--- /dev/null
+++ b/python/datasketches_spark/__init__.py
@@ -0,0 +1,22 @@
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
"""The Apache DataSketches Pyspark Library for Spark

Provided under the Apache License, Version 2.0
<http://www.apache.org/licenses/LICENSE-2.0>
"""

# Package name constant, available for introspection.
name = 'datasketches_spark'

# Re-export the public API of each submodule at package level.
from .common import *
# Underscore-prefixed names are skipped by the star import above; this
# helper is pulled in explicitly. NOTE(review): presumably intended for
# internal reuse by sketch modules — confirm it should be package-visible.
from .common import _invoke_function_over_columns
from .kll import *
diff --git a/python/datasketches_spark/common.py 
b/python/datasketches_spark/common.py
new file mode 100644
index 0000000..d487b05
--- /dev/null
+++ b/python/datasketches_spark/common.py
@@ -0,0 +1,76 @@
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from pyspark import SparkContext
+from pyspark.sql.column import Column, _to_java_column, _to_seq, 
_create_column_from_literal
+from py4j.java_gateway import JavaClass
from typing import Any, Callable, Optional, TypeVar, Union
+from functools import lru_cache
+
# Type alias: arguments that may be a pyspark Column or a column-name string.
ColumnOrName = Union[Column, str]
# Bound TypeVar for APIs that return the same flavor of column they accept.
ColumnOrName_ = TypeVar("ColumnOrName_", bound=ColumnOrName)
+
+# Since we have functions from different packages, rather than the
+# single 16k+ line functions class in core Spark, we'll have each
+# sketch family grab its own functions class from the JVM and cache it
+
def _get_jvm_class(name: str) -> JavaClass:
    """Look up the JVM class *name* through the active SparkContext's
    Java gateway.

    Requires an active SparkContext; asserts otherwise.
    """
    sc = SparkContext._active_spark_context
    assert sc is not None
    return getattr(sc._jvm, name)
+
@lru_cache
def _get_jvm_function(cls: JavaClass, name: str) -> Callable:
    """Resolve and memoize the JVM function *name* on class *cls*.

    The lru_cache avoids a py4j attribute round trip on every call to
    the same sketch function.
    """
    assert cls is not None
    return getattr(cls, name)
+
def _invoke_function(cls: JavaClass, name: str, *args: Any) -> Column:
    """Invoke the JVM function *name* on class *cls* with *args* and wrap
    the result in a :class:`~pyspark.sql.Column`.

    Args are forwarded to py4j as-is; callers are responsible for
    converting pyspark Columns via ``_to_java_column`` first.
    """
    # Removed leftover commented-out SparkContext assertion; the class
    # handle itself is the precondition here.
    assert cls is not None
    jf = _get_jvm_function(cls, name)
    return Column(jf(*args))
+
+
def _invoke_function_over_columns(cls: JavaClass, name: str, *cols: "ColumnOrName") -> Column:
    """Invoke the n-ary JVM function *name* over *cols*, converting each
    argument to a Java column first, and wrap the result in a
    :class:`~pyspark.sql.Column`.
    """
    jcols = [_to_java_column(c) for c in cols]
    return _invoke_function(cls, name, *jcols)
+
+
# Cached handle to org.apache.spark.sql.functions. Resolved lazily so a
# SparkContext exists before we touch the JVM gateway.
# Fix: annotated Optional — the original `JavaClass = None` annotation is
# a type error, since None is not a JavaClass.
_spark_functions_class: Optional[JavaClass] = None

def _get_spark_functions_class() -> JavaClass:
    """Return (and cache) the JVM class org.apache.spark.sql.functions."""
    global _spark_functions_class
    if _spark_functions_class is None:
        _spark_functions_class = _get_jvm_class("org.apache.spark.sql.functions")
    return _spark_functions_class
+
# borrowed from PySpark
def _array_as_java_column(data: Union[list, tuple]) -> Column:
    """Convert a Python list or tuple of literals to a Spark array Column.

    Mirrors pyspark.sql.functions.array: each element is lifted to a
    literal Java column, collected into a Scala Seq, and passed to the
    JVM array() function.
    """
    sc = SparkContext._active_spark_context
    jcols = [_create_column_from_literal(x) for x in data]
    # Fix: _to_seq already returns a py4j JavaObject (a Scala Seq); it has
    # no _jc attribute, so the original `._jc` access would raise. Pass
    # the Seq straight through, as PySpark's own array() helper does.
    return _invoke_function(_get_spark_functions_class(), "array", _to_seq(sc, jcols))
diff --git a/python/datasketches_spark/kll.py b/python/datasketches_spark/kll.py
new file mode 100644
index 0000000..64a99e5
--- /dev/null
+++ b/python/datasketches_spark/kll.py
@@ -0,0 +1,71 @@
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import List, Optional, Tuple, Union
+from py4j.java_gateway import JavaClass
+from pyspark.sql.column import Column, _to_java_column, _to_seq, 
_create_column_from_literal
+from pyspark.sql.functions import lit
+from pyspark.sql.utils import try_remote_functions
+
+from .common import (
+                     ColumnOrName,
+                     _invoke_function,
+                     _invoke_function_over_columns,
+                     _get_jvm_class,
+                     _array_as_java_column
+                    )
+
# Lazily-resolved handle to the KLL functions class on the JVM; stays None
# until first use so a SparkContext can be created first.
# Fix: annotated Optional — the original `JavaClass = None` annotation is
# a type error, since None is not a JavaClass.
_kll_functions_class: Optional[JavaClass] = None

def _get_kll_functions_class() -> JavaClass:
    """Return (and cache) org.apache.spark.sql.datasketches.kll.functions."""
    global _kll_functions_class
    if _kll_functions_class is None:
        _kll_functions_class = _get_jvm_class(
            "org.apache.spark.sql.datasketches.kll.functions"
        )
    return _kll_functions_class
+
+
@try_remote_functions
def kll_sketch_double_agg_build(col: "ColumnOrName", k: Optional[Union[int, Column]] = None) -> Column:
    """Aggregate function: build a KLL doubles sketch over the values of *col*.

    When *k* is given it is forwarded as the sketch size parameter; a bare
    int is wrapped in a literal Column first.
    """
    fns = _get_kll_functions_class()
    if k is None:
        return _invoke_function_over_columns(fns, "kll_sketch_double_agg_build", col)
    k_col = lit(k) if isinstance(k, int) else k
    return _invoke_function_over_columns(fns, "kll_sketch_double_agg_build", col, k_col)
+
@try_remote_functions
def kll_sketch_double_agg_merge(col: "ColumnOrName") -> Column:
    """Aggregate function: merge a column of KLL doubles sketches into one."""
    return _invoke_function_over_columns(
        _get_kll_functions_class(), "kll_sketch_double_agg_merge", col
    )
+
@try_remote_functions
def kll_sketch_double_get_min(col: "ColumnOrName") -> Column:
    """Return the minimum item retained by each sketch in *col*."""
    # _invoke_function_over_columns performs the same _to_java_column
    # conversion the original did inline.
    return _invoke_function_over_columns(
        _get_kll_functions_class(), "kll_sketch_double_get_min", col
    )
+
@try_remote_functions
def kll_sketch_double_get_max(col: "ColumnOrName") -> Column:
    """Return the maximum item retained by each sketch in *col*."""
    # _invoke_function_over_columns performs the same _to_java_column
    # conversion the original did inline.
    return _invoke_function_over_columns(
        _get_kll_functions_class(), "kll_sketch_double_get_max", col
    )
+
@try_remote_functions
def kll_sketch_double_get_pmf(col: "ColumnOrName", splitPoints: Union[List[float], Tuple[float], Column], isInclusive: bool = True) -> Column:
    """Return the PMF of each KLL doubles sketch in *col* over *splitPoints*.

    splitPoints may be a Python list/tuple of numbers (lifted to an array
    of literals) or a Column of doubles. isInclusive selects the quantile
    search criterion (INCLUSIVE when True, EXCLUSIVE when False).
    """
    if isinstance(splitPoints, (list, tuple)):
        # Lift the Python sequence to a Spark array Column of literals.
        splitPoints = _array_as_java_column(splitPoints)
    if isinstance(splitPoints, Column):
        # Unwrap to the underlying Java column for the py4j call; this now
        # also covers the array-literal branch above, whose result is a
        # pyspark Column that the original passed through unconverted.
        splitPoints = _to_java_column(splitPoints)

    # Fix: convert col to a Java column before invoking the JVM function,
    # consistent with kll_sketch_double_get_min/get_max.
    return _invoke_function(
        _get_kll_functions_class(),
        "kll_sketch_double_get_pmf",
        _to_java_column(col),
        splitPoints,
        isInclusive,
    )
+
@try_remote_functions
def kll_sketch_double_get_cdf(col: "ColumnOrName", splitPoints: Union[List[float], Tuple[float], Column], isInclusive: bool = True) -> Column:
    """Return the CDF of each KLL doubles sketch in *col* over *splitPoints*.

    splitPoints may be a Python list/tuple of numbers (lifted to an array
    of literals) or a Column of doubles; the annotation now includes Tuple
    to match the isinstance check and the get_pmf signature. isInclusive
    selects the quantile search criterion.
    """
    if isinstance(splitPoints, (list, tuple)):
        # Lift the Python sequence to a Spark array Column of literals.
        splitPoints = _array_as_java_column(splitPoints)
    if isinstance(splitPoints, Column):
        # Unwrap to the underlying Java column for the py4j call; this now
        # also covers the array-literal branch above, whose result is a
        # pyspark Column that the original passed through unconverted.
        splitPoints = _to_java_column(splitPoints)

    # Fix: convert col to a Java column before invoking the JVM function,
    # consistent with kll_sketch_double_get_min/get_max.
    return _invoke_function(
        _get_kll_functions_class(),
        "kll_sketch_double_get_cdf",
        _to_java_column(col),
        splitPoints,
        isInclusive,
    )
diff --git 
a/src/main/scala/org/apache/spark/sql/datasketches/kll/expressions/KllDoublesSketchExpressions.scala
 
b/src/main/scala/org/apache/spark/sql/datasketches/kll/expressions/KllDoublesSketchExpressions.scala
index f59b0dd..e326b32 100644
--- 
a/src/main/scala/org/apache/spark/sql/datasketches/kll/expressions/KllDoublesSketchExpressions.scala
+++ 
b/src/main/scala/org/apache/spark/sql/datasketches/kll/expressions/KllDoublesSketchExpressions.scala
@@ -72,7 +72,7 @@ case class KllDoublesSketchGetMin(sketchExpr: Expression)
     val code =
       s"""
          |${sketchEval.code}
-         |final org.apache.datasketches.kll.KllDoublesSketch $sketch = 
org.apache.spark.sql.types.KllDoublesSketchType.wrap(${sketchEval.value});
+         |final org.apache.datasketches.kll.KllDoublesSketch $sketch = 
org.apache.spark.sql.datasketches.kll.types.KllDoublesSketchType.wrap(${sketchEval.value});
          |final double ${ev.value} = $sketch.getMinItem();
        """.stripMargin
     ev.copy(code = CodeBlock(Seq(code), Seq.empty), isNull = sketchEval.isNull)
@@ -124,7 +124,7 @@ case class KllDoublesSketchGetMax(sketchExpr: Expression)
     val code =
       s"""
          |${sketchEval.code}
-         |final org.apache.datasketches.kll.KllDoublesSketch $sketch = 
org.apache.spark.sql.types.KllDoublesSketchType.wrap(${sketchEval.value});
+         |final org.apache.datasketches.kll.KllDoublesSketch $sketch = 
org.apache.spark.sql.datasketches.kll.types.KllDoublesSketchType.wrap(${sketchEval.value});
          |final double ${ev.value} = $sketch.getMaxItem();
        """.stripMargin
     ev.copy(code = CodeBlock(Seq(code), Seq.empty), isNull = sketchEval.isNull)
@@ -291,7 +291,7 @@ case class KllDoublesSketchGetPmfCdf(sketchExpr: Expression,
          |${splitPointsEval.code}
          |if (!${sketchEval.isNull} && !${splitPointsEval.isNull}) {
          |  org.apache.datasketches.quantilescommon.QuantileSearchCriteria 
searchCriteria = ${if (isInclusive) 
"org.apache.datasketches.quantilescommon.QuantileSearchCriteria.INCLUSIVE" else 
"org.apache.datasketches.quantilescommon.QuantileSearchCriteria.EXCLUSIVE"};
-         |  final org.apache.datasketches.kll.KllDoublesSketch $sketch = 
org.apache.spark.sql.types.KllDoublesSketchType.wrap(${sketchEval.value});
+         |  final org.apache.datasketches.kll.KllDoublesSketch $sketch = 
org.apache.spark.sql.datasketches.kll.types.KllDoublesSketchType.wrap(${sketchEval.value});
          |  final double[] splitPoints = 
((org.apache.spark.sql.catalyst.util.GenericArrayData)${splitPointsEval.value}).toDoubleArray();
          |  final double[] result = ${if (isPmf) s"$sketch.getPMF(splitPoints, 
searchCriteria)" else s"$sketch.getCDF(splitPoints, searchCriteria)"};
          |  org.apache.spark.sql.catalyst.util.GenericArrayData ${ev.value} = 
new org.apache.spark.sql.catalyst.util.GenericArrayData(result);


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to