This is an automated email from the ASF dual-hosted git repository. jmalkin pushed a commit to branch python in repository https://gitbox.apache.org/repos/asf/datasketches-spark.git
commit a1dd27810823a5e8e349a72d7ea2cb678683a56c Author: Jon Malkin <[email protected]> AuthorDate: Thu Jan 30 11:21:08 2025 -0800 WIP: Initial files for pyspark testing. Not fully working but checkpointing here --- python/datasketches_spark/__init__.py | 22 +++++++ python/datasketches_spark/common.py | 76 ++++++++++++++++++++++ python/datasketches_spark/kll.py | 71 ++++++++++++++++++++ .../expressions/KllDoublesSketchExpressions.scala | 6 +- 4 files changed, 172 insertions(+), 3 deletions(-) diff --git a/python/datasketches_spark/__init__.py b/python/datasketches_spark/__init__.py new file mode 100644 index 0000000..ef9effb --- /dev/null +++ b/python/datasketches_spark/__init__.py @@ -0,0 +1,22 @@ +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""The Apache DataSketches Pyspark Library for Spark + +Provided under the Apache License, Version 2.0 +<http://www.apache.org/licenses/LICENSE-2.0> +""" + +name = 'datasketches_spark' + +from .common import * +from .common import _invoke_function_over_columns +from .kll import * diff --git a/python/datasketches_spark/common.py b/python/datasketches_spark/common.py new file mode 100644 index 0000000..d487b05 --- /dev/null +++ b/python/datasketches_spark/common.py @@ -0,0 +1,76 @@ +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from pyspark import SparkContext +from pyspark.sql.column import Column, _to_java_column, _to_seq, _create_column_from_literal +from py4j.java_gateway import JavaClass +from typing import Any, TypeVar, Union, Callable +from functools import lru_cache + +ColumnOrName = Union[Column, str] +ColumnOrName_ = TypeVar("ColumnOrName_", bound=ColumnOrName) + +# Since we have functions from different packages, rather than the +# single 16k+ line functions class in core Spark, we'll have each +# sketch family grab its own functions class from the JVM and cache it + +def _get_jvm_class(name: str) -> JavaClass: + """ + Retrieves JVM class identified by name from + Java gateway associated with the current active Spark context. + """ + assert SparkContext._active_spark_context is not None + return getattr(SparkContext._active_spark_context._jvm, name) + +@lru_cache +def _get_jvm_function(cls: JavaClass, name: str) -> Callable: + """ + Retrieves JVM function identified by name from + Java gateway associated with sc. + """ + assert cls is not None + return getattr(cls, name) + +def _invoke_function(cls: JavaClass, name: str, *args: Any) -> Column: + """ + Invokes JVM function identified by name with args + and wraps the result with :class:`~pyspark.sql.Column`. + """ + #assert SparkContext._active_spark_context is not None + assert cls is not None + jf = _get_jvm_function(cls, name) + return Column(jf(*args)) + + +def _invoke_function_over_columns(cls: JavaClass, name: str, *cols: "ColumnOrName") -> Column: + """ + Invokes n-ary JVM function identified by name + and wraps the result with :class:`~pyspark.sql.Column`. + """ + return _invoke_function(cls, name, *(_to_java_column(col) for col in cols)) + + +# lazy init so we know the SparkContext exists first +_spark_functions_class: JavaClass = None + +def _get_spark_functions_class() -> JavaClass: + global _spark_functions_class + if _spark_functions_class is None: + _spark_functions_class = _get_jvm_class("org.apache.spark.sql.functions") + return _spark_functions_class + +# borrowed from PySpark +def _array_as_java_column(data: Union[list, tuple]) -> Column: + """ + Converts a Python list or tuple to a Spark DataFrame column. + """ + sc = SparkContext._active_spark_context + return _invoke_function(_get_spark_functions_class(), "array", _to_seq(sc, [_create_column_from_literal(x) for x in data])._jc) diff --git a/python/datasketches_spark/kll.py b/python/datasketches_spark/kll.py new file mode 100644 index 0000000..64a99e5 --- /dev/null +++ b/python/datasketches_spark/kll.py @@ -0,0 +1,71 @@ +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import List, Optional, Tuple, Union +from py4j.java_gateway import JavaClass +from pyspark.sql.column import Column, _to_java_column, _to_seq, _create_column_from_literal +from pyspark.sql.functions import lit +from pyspark.sql.utils import try_remote_functions + +from .common import ( + ColumnOrName, + _invoke_function, + _invoke_function_over_columns, + _get_jvm_class, + _array_as_java_column + ) + +_kll_functions_class: JavaClass = None + +def _get_kll_functions_class() -> JavaClass: + global _kll_functions_class + if _kll_functions_class is None: + _kll_functions_class = _get_jvm_class("org.apache.spark.sql.datasketches.kll.functions") + return _kll_functions_class + + +@try_remote_functions +def kll_sketch_double_agg_build(col: "ColumnOrName", k: Optional[Union[int, Column]] = None) -> Column: + if k is None: + return _invoke_function_over_columns(_get_kll_functions_class(), "kll_sketch_double_agg_build", col) + else: + _k = lit(k) if isinstance(k, int) else k + return _invoke_function_over_columns(_get_kll_functions_class(), "kll_sketch_double_agg_build", col, _k) + +@try_remote_functions +def kll_sketch_double_agg_merge(col: "ColumnOrName") -> Column: + return _invoke_function_over_columns(_get_kll_functions_class(), "kll_sketch_double_agg_merge", col) + +@try_remote_functions +def kll_sketch_double_get_min(col: "ColumnOrName") -> Column: + return _invoke_function(_get_kll_functions_class(), "kll_sketch_double_get_min", _to_java_column(col)) + +@try_remote_functions +def kll_sketch_double_get_max(col: "ColumnOrName") -> Column: + return _invoke_function(_get_kll_functions_class(), "kll_sketch_double_get_max", _to_java_column(col)) + +@try_remote_functions +def kll_sketch_double_get_pmf(col: "ColumnOrName", splitPoints: Union[List[float], Tuple[float], Column], isInclusive: bool = True) -> Column: + if isinstance(splitPoints, (list, tuple)): + splitPoints = _array_as_java_column(splitPoints) + elif isinstance(splitPoints, Column): + splitPoints = _to_java_column(splitPoints) + + return _invoke_function(_get_kll_functions_class(), "kll_sketch_double_get_pmf", col, splitPoints, isInclusive) + +@try_remote_functions +def kll_sketch_double_get_cdf(col: "ColumnOrName", splitPoints: Union[List[float], Column], isInclusive: bool = True) -> Column: + if isinstance(splitPoints, (list, tuple)): + splitPoints = _array_as_java_column(splitPoints) + elif isinstance(splitPoints, Column): + splitPoints = _to_java_column(splitPoints) + + return _invoke_function(_get_kll_functions_class(), "kll_sketch_double_get_cdf", col, splitPoints, isInclusive) diff --git a/src/main/scala/org/apache/spark/sql/datasketches/kll/expressions/KllDoublesSketchExpressions.scala b/src/main/scala/org/apache/spark/sql/datasketches/kll/expressions/KllDoublesSketchExpressions.scala index f59b0dd..e326b32 100644 --- a/src/main/scala/org/apache/spark/sql/datasketches/kll/expressions/KllDoublesSketchExpressions.scala +++ b/src/main/scala/org/apache/spark/sql/datasketches/kll/expressions/KllDoublesSketchExpressions.scala @@ -72,7 +72,7 @@ case class KllDoublesSketchGetMin(sketchExpr: Expression) val code = s""" |${sketchEval.code} - |final org.apache.datasketches.kll.KllDoublesSketch $sketch = org.apache.spark.sql.types.KllDoublesSketchType.wrap(${sketchEval.value}); + |final org.apache.datasketches.kll.KllDoublesSketch $sketch = org.apache.spark.sql.datasketches.kll.types.KllDoublesSketchType.wrap(${sketchEval.value}); |final double ${ev.value} = $sketch.getMinItem(); """.stripMargin ev.copy(code = CodeBlock(Seq(code), Seq.empty), isNull = sketchEval.isNull) @@ -124,7 +124,7 @@ case class KllDoublesSketchGetMax(sketchExpr: Expression) val code = s""" |${sketchEval.code} - |final org.apache.datasketches.kll.KllDoublesSketch $sketch = org.apache.spark.sql.types.KllDoublesSketchType.wrap(${sketchEval.value}); + |final org.apache.datasketches.kll.KllDoublesSketch $sketch = org.apache.spark.sql.datasketches.kll.types.KllDoublesSketchType.wrap(${sketchEval.value}); |final double ${ev.value} = $sketch.getMaxItem(); """.stripMargin ev.copy(code = CodeBlock(Seq(code), Seq.empty), isNull = sketchEval.isNull) @@ -291,7 +291,7 @@ case class KllDoublesSketchGetPmfCdf(sketchExpr: Expression, |${splitPointsEval.code} |if (!${sketchEval.isNull} && !${splitPointsEval.isNull}) { | org.apache.datasketches.quantilescommon.QuantileSearchCriteria searchCriteria = ${if (isInclusive) "org.apache.datasketches.quantilescommon.QuantileSearchCriteria.INCLUSIVE" else "org.apache.datasketches.quantilescommon.QuantileSearchCriteria.EXCLUSIVE"}; - | final org.apache.datasketches.kll.KllDoublesSketch $sketch = org.apache.spark.sql.types.KllDoublesSketchType.wrap(${sketchEval.value}); + | final org.apache.datasketches.kll.KllDoublesSketch $sketch = org.apache.spark.sql.datasketches.kll.types.KllDoublesSketchType.wrap(${sketchEval.value}); | final double[] splitPoints = ((org.apache.spark.sql.catalyst.util.GenericArrayData)${splitPointsEval.value}).toDoubleArray(); | final double[] result = ${if (isPmf) s"$sketch.getPMF(splitPoints, searchCriteria)" else s"$sketch.getCDF(splitPoints, searchCriteria)"}; | org.apache.spark.sql.catalyst.util.GenericArrayData ${ev.value} = new org.apache.spark.sql.catalyst.util.GenericArrayData(result); --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
