This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new e6d44e449e1e [SPARK-51100][ML][PYTHON][CONNECT] Replace transformer wrappers with helper model attribute relations
e6d44e449e1e is described below
commit e6d44e449e1eb0f1648e3e7834be7299424a4397
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Feb 6 09:06:39 2025 +0800
[SPARK-51100][ML][PYTHON][CONNECT] Replace transformer wrappers with helper model attribute relations
### What changes were proposed in this pull request?
Replace the dedicated transformer wrappers (PowerIterationClusteringWrapper, PrefixSpanWrapper, ChiSquareTestWrapper, CorrelationWrapper, KolmogorovSmirnovTestWrapper) with methods on the ConnectHelper model, invoked as attribute relations.
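In sketch form (simplified from the Python diffs below; remote code path only, session setup omitted):

    # Before: build a throwaway JavaTransformer pointing at a dedicated
    # *Wrapper class and attach serialized params to it.
    instance = JavaTransformer()
    instance._java_obj = "org.apache.spark.ml.stat.CorrelationWrapper"
    instance._serialized_ml_params = serialize_ml_params_values(
        {"featuresCol": column, "method": method}, session.client)
    result = instance.transform(dataset)

    # After: invoke a method on the shared ConnectHelper object, whose
    # result comes back as an attribute relation.
    result = invoke_helper_relation("correlation", dataset, column, method)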
### Why are the changes needed?
To simplify the implementation: one helper method per algorithm replaces a full wrapper Transformer with its own Param plumbing.
### Does this PR introduce _any_ user-facing change?
No, this is a refactoring-only change; the public APIs are unchanged.
### How was this patch tested?
Existing tests cover this change.
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #49819 from zhengruifeng/ml_connect_wrapper_to_helper.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../services/org.apache.spark.ml.Transformer | 10 ----
.../ml/clustering/PowerIterationClustering.scala | 28 ----------
.../scala/org/apache/spark/ml/fpm/PrefixSpan.scala | 26 ---------
.../org/apache/spark/ml/stat/ChiSquareTest.scala | 44 +---------------
.../org/apache/spark/ml/stat/Correlation.scala | 29 ----------
.../spark/ml/stat/KolmogorovSmirnovTest.scala | 35 -------------
.../org/apache/spark/ml/util/ConnectHelper.scala | 61 ++++++++++++++++++++++
python/pyspark/ml/clustering.py | 19 +++----
python/pyspark/ml/feature.py | 4 +-
python/pyspark/ml/fpm.py | 25 ++++-----
python/pyspark/ml/stat.py | 42 ++++-----------
python/pyspark/ml/util.py | 60 ++++++++++++---------
.../org/apache/spark/sql/connect/ml/MLUtils.scala | 7 ++-
.../apache/spark/sql/connect/ml/Serializer.scala | 12 +++++
14 files changed, 149 insertions(+), 253 deletions(-)
diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
index 247c9c912f5a..fc6a8166442a 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
@@ -18,9 +18,6 @@
# Spark Connect ML uses ServiceLoader to find out the supported Spark Ml non-model transformer.
# So register the supported transformer here if you're trying to add a new one.
-########### Helper Model
-org.apache.spark.ml.util.ConnectHelper
-
########### Transformers
org.apache.spark.ml.feature.DCT
org.apache.spark.ml.feature.NGram
@@ -68,19 +65,12 @@ org.apache.spark.ml.clustering.BisectingKMeansModel
org.apache.spark.ml.clustering.GaussianMixtureModel
org.apache.spark.ml.clustering.DistributedLDAModel
org.apache.spark.ml.clustering.LocalLDAModel
-org.apache.spark.ml.clustering.PowerIterationClusteringWrapper
# recommendation
org.apache.spark.ml.recommendation.ALSModel
# fpm
org.apache.spark.ml.fpm.FPGrowthModel
-org.apache.spark.ml.fpm.PrefixSpanWrapper
-
-# stat
-org.apache.spark.ml.stat.ChiSquareTestWrapper
-org.apache.spark.ml.stat.CorrelationWrapper
-org.apache.spark.ml.stat.KolmogorovSmirnovTestWrapper
# feature
org.apache.spark.ml.feature.RFormulaModel
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala
index 6e7028d8f99e..8b2ee955d6a5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala
@@ -18,7 +18,6 @@
package org.apache.spark.ml.clustering
import org.apache.spark.annotation.Since
-import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
@@ -192,30 +191,3 @@ object PowerIterationClustering extends DefaultParamsReadable[PowerIterationClus
@Since("2.4.0")
override def load(path: String): PowerIterationClustering = super.load(path)
}
-
-private[spark] class PowerIterationClusteringWrapper(override val uid: String)
- extends Transformer with PowerIterationClusteringParams with DefaultParamsWritable {
-
- def this() = this(Identifiable.randomUID("PowerIterationClusteringWrapper"))
-
- override def transform(dataset: Dataset[_]): DataFrame = {
- val pic = new PowerIterationClustering()
- .setK($(k))
- .setInitMode($(initMode))
- .setMaxIter($(maxIter))
- .setSrcCol($(srcCol))
- .setDstCol($(dstCol))
- get(weightCol) match {
- case Some(w) if w.nonEmpty => pic.setWeightCol(w)
- case _ =>
- }
- pic.assignClusters(dataset)
- }
-
- override def transformSchema(schema: StructType): StructType =
- new StructType()
- .add(StructField("id", LongType, nullable = false))
- .add(StructField("cluster", IntegerType, nullable = false))
-
- override def copy(extra: ParamMap): PowerIterationClusteringWrapper = defaultCopy(extra)
-}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala
index 099e42ee2749..3ebfbad310aa 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala
@@ -18,7 +18,6 @@
package org.apache.spark.ml.fpm
import org.apache.spark.annotation.Since
-import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.util.Instrumentation.instrumented
@@ -168,29 +167,4 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends PrefixS
@Since("2.4.0")
override def copy(extra: ParamMap): PrefixSpan = defaultCopy(extra)
-
-}
-
-private[spark] class PrefixSpanWrapper(override val uid: String)
- extends Transformer with PrefixSpanParams {
-
- def this() = this(Identifiable.randomUID("prefixSpanWrapper"))
-
- override def transformSchema(schema: StructType): StructType = {
- new StructType()
- .add("sequence", ArrayType(schema($(sequenceCol)).dataType), nullable =
false)
- .add("freq", LongType, nullable = false)
- }
-
- override def transform(dataset: Dataset[_]): DataFrame = {
- val prefixSpan = new PrefixSpan(uid)
- prefixSpan
- .setMinSupport($(minSupport))
- .setMaxPatternLength($(maxPatternLength))
- .setMaxLocalProjDBSize($(maxLocalProjDBSize))
- .setSequenceCol($(sequenceCol))
- .findFrequentSequentialPatterns(dataset)
- }
-
- override def copy(extra: ParamMap): PrefixSpanWrapper = defaultCopy(extra)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala
index 863ec640e7b7..2207c20049a1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala
@@ -18,17 +18,12 @@
package org.apache.spark.ml.stat
import org.apache.spark.annotation.Since
-import org.apache.spark.ml.Transformer
import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
-import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.util.SchemaUtils
import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
import org.apache.spark.mllib.stat.test.{ChiSqTest => OldChiSqTest}
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types._
-
/**
* Chi-square hypothesis testing for categorical data.
@@ -105,40 +100,3 @@ object ChiSquareTest {
}
}
}
-
-/**
- * [[ChiSquareTest]] is not an Estimator/Transformer and thus needs to be wrapped in a wrapper
- * to be compatible with Spark Connect.
- */
-private[spark] class ChiSquareTestWrapper(override val uid: String)
- extends Transformer with HasFeaturesCol with HasLabelCol {
-
- val flatten = new BooleanParam(this, "flatten",
- "If false, the returned DataFrame contains only a single Row, otherwise,
one row per feature.")
-
- setDefault(flatten -> false)
-
- def this() = this(Identifiable.randomUID("ChiSquareTestWrapper"))
-
- override def transformSchema(schema: StructType): StructType = {
- if ($(flatten)) {
- new StructType()
- .add("featureIndex", IntegerType, nullable = false)
- .add("pValue", DoubleType, nullable = false)
- .add("degreesOfFreedom", IntegerType, nullable = false)
- .add("statistic", DoubleType, nullable = false)
- } else {
- new StructType()
- .add("pValues", new VectorUDT, nullable = false)
- .add("degreesOfFreedom", ArrayType(IntegerType, containsNull = false),
nullable = false)
- .add("statistics", new VectorUDT, nullable = false)
- }
- }
-
- override def transform(dataset: Dataset[_]): DataFrame = {
- ChiSquareTest.test(dataset.toDF(), $(featuresCol), $(labelCol), $(flatten))
- }
-
- override def copy(extra: ParamMap): ChiSquareTestWrapper = defaultCopy(extra)
-}
-
diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala
index 3c06abb23797..46fabc9808a7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala
@@ -20,11 +20,7 @@ package org.apache.spark.ml.stat
import scala.jdk.CollectionConverters._
import org.apache.spark.annotation.Since
-import org.apache.spark.ml._
import org.apache.spark.ml.linalg.{SQLDataTypes, Vector}
-import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.HasFeaturesCol
-import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
import org.apache.spark.mllib.stat.{Statistics => OldStatistics}
import org.apache.spark.sql.{DataFrame, Dataset, Row}
@@ -86,28 +82,3 @@ object Correlation {
corr(dataset, column, "pearson")
}
}
-
-/**
- * [[Correlation]] is not an Estimator/Transformer and thus needs to be wrapped in a wrapper
- * to be compatible with Spark Connect.
- */
-private[spark] class CorrelationWrapper(override val uid: String)
- extends Transformer with HasFeaturesCol {
-
- val method = new Param[String](this, "method", "The correlation method to use")
-
- setDefault(method -> "pearson")
-
- def this() = this(Identifiable.randomUID("CorrelationWrapper"))
-
- override def transformSchema(schema: StructType): StructType = {
- val name = s"${$(method)}(${$(featuresCol)})"
- StructType(Array(StructField(name, SQLDataTypes.MatrixType, nullable = false)))
- }
-
- override def transform(dataset: Dataset[_]): DataFrame = {
- Correlation.corr(dataset, $(featuresCol), $(method))
- }
-
- override def copy(extra: ParamMap): CorrelationWrapper = defaultCopy(extra)
-}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala
index 2fc4a856b564..c11163949ab1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala
@@ -21,15 +21,11 @@ import scala.annotation.varargs
import org.apache.spark.annotation.Since
import org.apache.spark.api.java.function.Function
-import org.apache.spark.ml._
-import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.HasInputCol
import org.apache.spark.ml.util._
import org.apache.spark.mllib.stat.{Statistics => OldStatistics}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions.col
-import org.apache.spark.sql.types._
/**
* Conduct the two-sided Kolmogorov Smirnov (KS) test for data sampled from a
@@ -118,34 +114,3 @@ object KolmogorovSmirnovTest {
testResult.pValue, testResult.statistic)))
}
}
-
-
-/**
- * [[KolmogorovSmirnovTest]] is not an Estimator/Transformer and thus needs to be wrapped
- * in a wrapper to be compatible with Spark Connect.
- */
-private[spark] class KolmogorovSmirnovTestWrapper(override val uid: String)
- extends Transformer with HasInputCol {
-
- val paramsArray = new DoubleArrayParam(this, "paramsArray",
- "The parameters to be used for the theoretical distribution.")
-
- val distName = new Param[String](this, "distName",
- "The name of the theoretical distribution to test against")
-
- setDefault(paramsArray -> Array.emptyDoubleArray)
-
- def this() = this(Identifiable.randomUID("KolmogorovSmirnovTestWrapper"))
-
- override def transformSchema(schema: StructType): StructType = {
- new StructType()
- .add("pValue", DoubleType, nullable = false)
- .add("statistic", DoubleType, nullable = false)
- }
-
- override def transform(dataset: Dataset[_]): DataFrame = {
- KolmogorovSmirnovTest.test(dataset, $(inputCol), $(distName), $(paramsArray).toIndexedSeq: _*)
- }
-
- override def copy(extra: ParamMap): KolmogorovSmirnovTestWrapper = defaultCopy(extra)
-}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ConnectHelper.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ConnectHelper.scala
index fb2e1a0c0b4e..834337692cab 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ConnectHelper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ConnectHelper.scala
@@ -17,8 +17,11 @@
package org.apache.spark.ml.util
import org.apache.spark.ml.Model
+import org.apache.spark.ml.clustering.PowerIterationClustering
import org.apache.spark.ml.feature._
+import org.apache.spark.ml.fpm.PrefixSpan
import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.stat._
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.types.StructType
@@ -54,6 +57,64 @@ private[spark] class ConnectHelper(override val uid: String) extends Model[Conne
StopWordsRemover.getDefaultOrUS.toString
}
+ def chiSquareTest(
+ dataset: DataFrame,
+ featuresCol: String,
+ labelCol: String,
+ flatten: Boolean): DataFrame = {
+ ChiSquareTest.test(dataset, featuresCol, labelCol, flatten)
+ }
+
+ def correlation(
+ dataset: DataFrame,
+ column: String,
+ method: String): DataFrame = {
+ Correlation.corr(dataset, column, method)
+ }
+
+ def kolmogorovSmirnovTest(
+ dataset: DataFrame,
+ sampleCol: String,
+ distName: String,
+ params: Array[Double]): DataFrame = {
+ KolmogorovSmirnovTest.test(dataset, sampleCol, distName, params.toIndexedSeq: _*)
+ }
+
+ def powerIterationClusteringAssignClusters(
+ dataset: DataFrame,
+ k: Int,
+ maxIter: Int,
+ initMode: String,
+ srcCol: String,
+ dstCol: String,
+ weightCol: String): DataFrame = {
+ val pic = new PowerIterationClustering()
+ .setK(k)
+ .setMaxIter(maxIter)
+ .setInitMode(initMode)
+ .setSrcCol(srcCol)
+ .setDstCol(dstCol)
+ if (weightCol.nonEmpty) {
+ pic.setWeightCol(weightCol)
+ }
+ pic.assignClusters(dataset)
+ }
+
+ def prefixSpanFindFrequentSequentialPatterns(
+ dataset: DataFrame,
+ minSupport: Double,
+ maxPatternLength: Int,
+ maxLocalProjDBSize: Long,
+ sequenceCol: String): DataFrame = {
+ val prefixSpan = new PrefixSpan()
+ .setMinSupport(minSupport)
+ .setMaxPatternLength(maxPatternLength)
+ .setMaxLocalProjDBSize(maxLocalProjDBSize)
+ .setSequenceCol(sequenceCol)
+ prefixSpan.findFrequentSequentialPatterns(dataset)
+ }
+
+
override def copy(extra: ParamMap): ConnectHelper = defaultCopy(extra)
override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF()
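Each helper method above is a thin delegate to the corresponding public API. A sketch of the round trip for one of them (assuming an active Spark Connect session and a DataFrame df with "features" and "label" columns):

    # Client side (see python/pyspark/ml/stat.py below):
    result = invoke_helper_relation("chiSquareTest", df, "features", "label", True)
    # Server side, MLUtils routes the call to ConnectHelper.chiSquareTest,
    # which simply runs ChiSquareTest.test(df, "features", "label", flatten).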
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index dcd34ba365a5..a0710f435f3b 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -46,6 +46,7 @@ from pyspark.ml.util import (
GeneralJavaMLWritable,
HasTrainingSummary,
try_remote_attribute_relation,
+ invoke_helper_relation,
)
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, JavaWrapper
from pyspark.ml.common import inherit_doc
@@ -2155,16 +2156,16 @@ class PowerIterationClustering(
assert self._java_obj is not None
if is_remote():
- from pyspark.ml.wrapper import JavaTransformer
- from pyspark.ml.connect.serialize import serialize_ml_params
-
- instance = JavaTransformer()
- instance._java_obj = "org.apache.spark.ml.clustering.PowerIterationClusteringWrapper"
- instance._serialized_ml_params = serialize_ml_params( # type: ignore[attr-defined]
- self,
- dataset.sparkSession.client, # type: ignore[arg-type,operator]
+ return invoke_helper_relation(
+ "powerIterationClusteringAssignClusters",
+ dataset,
+ self.getK(),
+ self.getMaxIter(),
+ self.getInitMode(),
+ self.getSrcCol(),
+ self.getDstCol(),
+ self.getWeightCol() if self.isDefined(self.weightCol) else "",
)
- return instance.transform(dataset)
self._transfer_params_to_java()
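The user-facing behavior of assignClusters is unchanged; a quick runnable check (assuming an active SparkSession bound to spark):

    from pyspark.ml.clustering import PowerIterationClustering

    df = spark.createDataFrame(
        [(0, 1, 1.0), (0, 2, 1.0), (1, 2, 1.0), (3, 4, 1.0)],
        ["src", "dst", "weight"])
    pic = PowerIterationClustering(k=2, weightCol="weight")
    pic.assignClusters(df).show()  # columns: id, cluster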
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 3797b0f2b04c..6a4a9dc99875 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -4849,7 +4849,7 @@ class StringIndexerModel(
model._java_obj = helper._call_java(
"stringIndexerModelFromLabels",
model.uid,
- (list(labels), ArrayType(StringType(), False)),
+ (list(labels), ArrayType(StringType())),
)
else:
@@ -4891,7 +4891,7 @@ class StringIndexerModel(
model.uid,
(
[list(labels) for labels in arrayOfLabels],
- ArrayType(ArrayType(StringType(), False)),
+ ArrayType(ArrayType(StringType())),
),
)
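Note on the ArrayType change above: in PySpark, containsNull defaults to True, so dropping the explicit False relaxes the declared element nullability of the serialized labels:

    from pyspark.sql.types import ArrayType, StringType

    ArrayType(StringType())         # containsNull=True, the new form
    ArrayType(StringType(), False)  # the old form: nulls disallowed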
diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py
index 0e46ecc45e93..64b7e5eae556 100644
--- a/python/pyspark/ml/fpm.py
+++ b/python/pyspark/ml/fpm.py
@@ -20,7 +20,12 @@ from typing import Any, Dict, Optional, TYPE_CHECKING
from pyspark import keyword_only, since
from pyspark.sql import DataFrame
-from pyspark.ml.util import JavaMLWritable, JavaMLReadable, try_remote_attribute_relation
+from pyspark.ml.util import (
+ JavaMLWritable,
+ JavaMLReadable,
+ try_remote_attribute_relation,
+ invoke_helper_relation,
+)
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams
from pyspark.ml.param.shared import HasPredictionCol, Param, TypeConverters, Params
@@ -511,18 +516,14 @@ class PrefixSpan(JavaParams):
assert self._java_obj is not None
if is_remote():
- from pyspark.ml.wrapper import JavaTransformer
- from pyspark.ml.connect.serialize import serialize_ml_params
-
- instance = JavaTransformer()
- instance._java_obj = "org.apache.spark.ml.fpm.PrefixSpanWrapper"
- # The helper object is just a JavaTransformer without any Param Mixin,
- # copying the params by .copy() or directly assigning the _paramMap won't work
- instance._serialized_ml_params = serialize_ml_params( # type: ignore[attr-defined]
- self,
- dataset.sparkSession.client, # type: ignore[arg-type,operator]
+ return invoke_helper_relation(
+ "prefixSpanFindFrequentSequentialPatterns",
+ dataset,
+ self.getMinSupport(),
+ self.getMaxPatternLength(),
+ self.getMaxLocalProjDBSize(),
+ self.getSequenceCol(),
)
- return instance.transform(dataset)
self._transfer_params_to_java()
jdf = self._java_obj.findFrequentSequentialPatterns(dataset._jdf)
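findFrequentSequentialPatterns keeps its public contract; a small runnable example (assuming an active SparkSession bound to spark):

    from pyspark.ml.fpm import PrefixSpan
    from pyspark.sql import Row

    df = spark.createDataFrame(
        [Row(sequence=[[1, 2], [3]]),
         Row(sequence=[[1], [3, 2], [1, 2]])])
    PrefixSpan(minSupport=0.5, maxPatternLength=5) \
        .findFrequentSequentialPatterns(df).show()  # columns: sequence, freq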
diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py
index 16fa9f7edb80..b9d40d46fd09 100644
--- a/python/pyspark/ml/stat.py
+++ b/python/pyspark/ml/stat.py
@@ -22,9 +22,11 @@ from pyspark import since
from pyspark.ml.common import _java2py, _py2java
from pyspark.ml.linalg import Matrix, Vector
from pyspark.ml.wrapper import JavaWrapper, _jvm
+from pyspark.ml.util import invoke_helper_relation
from pyspark.sql.column import Column
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.functions import lit
+from pyspark.sql.types import ArrayType, DoubleType
from pyspark.sql.utils import is_remote
if TYPE_CHECKING:
@@ -104,17 +106,7 @@ class ChiSquareTest:
4.0
"""
if is_remote():
- from pyspark.ml.wrapper import JavaTransformer
- from pyspark.ml.connect.serialize import serialize_ml_params_values
-
- instance = JavaTransformer()
- instance._java_obj = "org.apache.spark.ml.stat.ChiSquareTestWrapper"
- serialized_ml_params = serialize_ml_params_values(
- {"featuresCol": featuresCol, "labelCol": labelCol, "flatten":
flatten},
- dataset.sparkSession.client, # type: ignore[arg-type,operator]
- )
- instance._serialized_ml_params = serialized_ml_params # type: ignore[attr-defined]
- return instance.transform(dataset)
+ return invoke_helper_relation("chiSquareTest", dataset, featuresCol, labelCol, flatten)
else:
from pyspark.core.context import SparkContext
@@ -189,17 +181,7 @@ class Correlation:
[ 0.4 , 0.9486... , NaN, 1. ]])
"""
if is_remote():
- from pyspark.ml.wrapper import JavaTransformer
- from pyspark.ml.connect.serialize import serialize_ml_params_values
-
- instance = JavaTransformer()
- instance._java_obj = "org.apache.spark.ml.stat.CorrelationWrapper"
- serialized_ml_params = serialize_ml_params_values(
- {"featuresCol": column, "method": method},
- dataset.sparkSession.client, # type: ignore[arg-type,operator]
- )
- instance._serialized_ml_params = serialized_ml_params # type: ignore[attr-defined]
- return instance.transform(dataset)
+ return invoke_helper_relation("correlation", dataset, column, method)
else:
from pyspark.core.context import SparkContext
@@ -273,17 +255,13 @@ class KolmogorovSmirnovTest:
0.175
"""
if is_remote():
- from pyspark.ml.wrapper import JavaTransformer
- from pyspark.ml.connect.serialize import serialize_ml_params_values
-
- instance = JavaTransformer()
- instance._java_obj = "org.apache.spark.ml.stat.KolmogorovSmirnovTestWrapper"
- serialized_ml_params = serialize_ml_params_values(
- {"inputCol": sampleCol, "distName": distName, "paramsArray":
list(params)},
- dataset.sparkSession.client, # type: ignore[arg-type,operator]
+ return invoke_helper_relation(
+ "kolmogorovSmirnovTest",
+ dataset,
+ sampleCol,
+ distName,
+ ([float(p) for p in params], ArrayType(DoubleType())),
)
- instance._serialized_ml_params = serialized_ml_params # type: ignore[attr-defined]
- return instance.transform(dataset)
else:
from pyspark.core.context import SparkContext
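The three entry points above keep their signatures; a runnable smoke test of all of them (assuming an active SparkSession bound to spark):

    from pyspark.ml.linalg import Vectors
    from pyspark.ml.stat import ChiSquareTest, Correlation, KolmogorovSmirnovTest

    df = spark.createDataFrame(
        [(0.0, Vectors.dense(0.5, 10.0)),
         (1.0, Vectors.dense(1.5, 20.0)),
         (1.0, Vectors.dense(1.5, 30.0))],
        ["label", "features"])
    ChiSquareTest.test(df, "features", "label", flatten=True).show()
    Correlation.corr(df, "features", "pearson").show()

    ks_df = spark.createDataFrame([(0.1,), (0.5,), (-0.3,)], ["sample"])
    KolmogorovSmirnovTest.test(ks_df, "sample", "norm", 0.0, 1.0).show()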
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 666ebb0071c7..7b8ba57a1f8a 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -78,6 +78,37 @@ def try_remote_intermediate_result(f: FuncT) -> FuncT:
return cast(FuncT, wrapped)
+def invoke_helper_relation(method: str, *args: Any) -> "ConnectDataFrame":
+ from pyspark.ml.wrapper import JavaWrapper
+
+ helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID)
+ return invoke_remote_attribute_relation(helper, method, *args)
+
+
+def invoke_remote_attribute_relation(
+ instance: "JavaWrapper", method: str, *args: Any
+) -> "ConnectDataFrame":
+ import pyspark.sql.connect.proto as pb2
+ from pyspark.ml.connect.util import _extract_id_methods
+ from pyspark.ml.connect.serialize import serialize
+
+ # The attribute returns a dataframe, we need to wrap it
+ # in the AttributeRelation
+ from pyspark.ml.connect.proto import AttributeRelation
+ from pyspark.sql.connect.session import SparkSession
+ from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
+
+ session = SparkSession.getActiveSession()
+ assert session is not None
+
+ assert isinstance(instance._java_obj, str)
+
+ methods, obj_ref = _extract_id_methods(instance._java_obj)
+ methods.append(pb2.Fetch.Method(method=method, args=serialize(session.client, *args)))
+ plan = AttributeRelation(obj_ref, methods)
+ return ConnectDataFrame(plan, session)
+
+
def try_remote_attribute_relation(f: FuncT) -> FuncT:
"""Mark the function/property that returns a Relation.
Eg, model.summary.roc"""
@@ -85,27 +116,7 @@ def try_remote_attribute_relation(f: FuncT) -> FuncT:
@functools.wraps(f)
def wrapped(self: "JavaWrapper", *args: Any, **kwargs: Any) -> Any:
if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
- import pyspark.sql.connect.proto as pb2
- from pyspark.ml.connect.util import _extract_id_methods
- from pyspark.ml.connect.serialize import serialize
-
- # The attribute returns a dataframe, we need to wrap it
- # in the AttributeRelation
- from pyspark.ml.connect.proto import AttributeRelation
- from pyspark.sql.connect.session import SparkSession
- from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
-
- session = SparkSession.getActiveSession()
- assert session is not None
-
- assert isinstance(self._java_obj, str)
-
- methods, obj_ref = _extract_id_methods(self._java_obj)
- methods.append(
- pb2.Fetch.Method(method=f.__name__, args=serialize(session.client, *args))
- )
- plan = AttributeRelation(obj_ref, methods)
- return ConnectDataFrame(plan, session)
+ return invoke_remote_attribute_relation(self, f.__name__, *args)
else:
return f(self, *args, **kwargs)
@@ -161,16 +172,12 @@ def try_remote_transform_relation(f: FuncT) -> FuncT:
session = dataset.sparkSession
assert session is not None
- if hasattr(self, "_serialized_ml_params"):
- params = self._serialized_ml_params
- else:
- params = serialize_ml_params(self, session.client) # type: ignore[arg-type]
-
# Model is also a Transformer, so we must match Model first
if isinstance(self, Model):
from pyspark.ml.connect.proto import TransformerRelation
assert isinstance(self._java_obj, str)
+ params = serialize_ml_params(self, session.client)
return ConnectDataFrame(
TransformerRelation(
child=dataset._plan, name=self._java_obj, ml_params=params, is_model=True
@@ -181,6 +188,7 @@ def try_remote_transform_relation(f: FuncT) -> FuncT:
from pyspark.ml.connect.proto import TransformerRelation
assert isinstance(self._java_obj, str)
+ params = serialize_ml_params(self, session.client)
return ConnectDataFrame(
TransformerRelation(
child=dataset._plan,
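For orientation, what invoke_helper_relation builds (a simplified restatement of the code above, assuming an active Spark Connect session): it wraps the constant ML_CONNECT_HELPER_ID in a JavaWrapper, appends a Fetch.Method carrying the serialized arguments, and exposes the result as an AttributeRelation-backed DataFrame:

    helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID)  # shared server-side helper
    methods, obj_ref = _extract_id_methods(helper._java_obj)
    methods.append(pb2.Fetch.Method(
        method="correlation",
        args=serialize(session.client, df, "features", "pearson")))
    # Equivalent to: invoke_helper_relation("correlation", df, "features", "pearson")
    df_out = ConnectDataFrame(AttributeRelation(obj_ref, methods), session)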
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
index b613b2202137..84a26d9e4962 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
@@ -661,7 +661,12 @@ private[ml] object MLUtils {
"stringIndexerModelFromLabelsArray",
"countVectorizerModelFromVocabulary",
"stopWordsRemoverLoadDefaultStopWords",
- "stopWordsRemoverGetDefaultOrUS")))
+ "stopWordsRemoverGetDefaultOrUS",
+ "chiSquareTest",
+ "correlation",
+ "kolmogorovSmirnovTest",
+ "powerIterationClusteringAssignClusters",
+ "prefixSpanFindFrequentSequentialPatterns")))
private def validate(obj: Any, method: String): Unit = {
assert(obj != null)
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala
index 2bbc0b258cad..df07dd42bc42 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala
@@ -167,6 +167,8 @@ private[ml] object Serializer {
case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
val array = literal.getArray
array.getElementType.getKindCase match {
+ case proto.DataType.KindCase.DOUBLE =>
+ (parseDoubleArray(array), classOf[Array[Double]])
case proto.DataType.KindCase.STRING =>
(parseStringArray(array), classOf[Array[String]])
case proto.DataType.KindCase.ARRAY =>
@@ -191,6 +193,16 @@ private[ml] object Serializer {
}
}
+ private def parseDoubleArray(array: proto.Expression.Literal.Array): Array[Double] = {
+ val values = new Array[Double](array.getElementsCount)
+ var i = 0
+ while (i < array.getElementsCount) {
+ values(i) = array.getElements(i).getDouble
+ i += 1
+ }
+ values
+ }
+
private def parseStringArray(array: proto.Expression.Literal.Array): Array[String] = {
val values = new Array[String](array.getElementsCount)
var i = 0
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]