This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e6d44e449e1e [SPARK-51100][ML][PYTHON][CONNECT] Replace transformer wrappers with helper model attribute relations
e6d44e449e1e is described below

commit e6d44e449e1eb0f1648e3e7834be7299424a4397
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Feb 6 09:06:39 2025 +0800

    [SPARK-51100][ML][PYTHON][CONNECT] Replace transformer wrappers with helper model attribute relations
    
    ### What changes were proposed in this pull request?
    Replace the transformer wrappers with attribute relations on the helper model.
    
    ### Why are the changes needed?
    To simplify the implementation.
    
    ### Does this PR introduce _any_ user-facing change?
    No, this is a refactoring only.
    
    ### How was this patch tested?
    Existing tests should cover this change.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #49819 from zhengruifeng/ml_connect_wrapper_to_helper.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../services/org.apache.spark.ml.Transformer       | 10 ----
 .../ml/clustering/PowerIterationClustering.scala   | 28 ----------
 .../scala/org/apache/spark/ml/fpm/PrefixSpan.scala | 26 ---------
 .../org/apache/spark/ml/stat/ChiSquareTest.scala   | 44 +---------------
 .../org/apache/spark/ml/stat/Correlation.scala     | 29 ----------
 .../spark/ml/stat/KolmogorovSmirnovTest.scala      | 35 -------------
 .../org/apache/spark/ml/util/ConnectHelper.scala   | 61 ++++++++++++++++++++++
 python/pyspark/ml/clustering.py                    | 19 +++----
 python/pyspark/ml/feature.py                       |  4 +-
 python/pyspark/ml/fpm.py                           | 25 ++++-----
 python/pyspark/ml/stat.py                          | 42 ++++-----------
 python/pyspark/ml/util.py                          | 60 ++++++++++++---------
 .../org/apache/spark/sql/connect/ml/MLUtils.scala  |  7 ++-
 .../apache/spark/sql/connect/ml/Serializer.scala   | 12 +++++
 14 files changed, 149 insertions(+), 253 deletions(-)

diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
index 247c9c912f5a..fc6a8166442a 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
@@ -18,9 +18,6 @@
 # Spark Connect ML uses ServiceLoader to find out the supported Spark Ml non-model transformer.
 # So register the supported transformer here if you're trying to add a new one.
 
-########### Helper Model
-org.apache.spark.ml.util.ConnectHelper
-
 ########### Transformers
 org.apache.spark.ml.feature.DCT
 org.apache.spark.ml.feature.NGram
@@ -68,19 +65,12 @@ org.apache.spark.ml.clustering.BisectingKMeansModel
 org.apache.spark.ml.clustering.GaussianMixtureModel
 org.apache.spark.ml.clustering.DistributedLDAModel
 org.apache.spark.ml.clustering.LocalLDAModel
-org.apache.spark.ml.clustering.PowerIterationClusteringWrapper
 
 # recommendation
 org.apache.spark.ml.recommendation.ALSModel
 
 # fpm
 org.apache.spark.ml.fpm.FPGrowthModel
-org.apache.spark.ml.fpm.PrefixSpanWrapper
-
-# stat
-org.apache.spark.ml.stat.ChiSquareTestWrapper
-org.apache.spark.ml.stat.CorrelationWrapper
-org.apache.spark.ml.stat.KolmogorovSmirnovTestWrapper
 
 # feature
 org.apache.spark.ml.feature.RFormulaModel
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala
index 6e7028d8f99e..8b2ee955d6a5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.ml.clustering
 
 import org.apache.spark.annotation.Since
-import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
@@ -192,30 +191,3 @@ object PowerIterationClustering extends DefaultParamsReadable[PowerIterationClus
   @Since("2.4.0")
   override def load(path: String): PowerIterationClustering = super.load(path)
 }
-
-private[spark] class PowerIterationClusteringWrapper(override val uid: String)
-  extends Transformer with PowerIterationClusteringParams with DefaultParamsWritable {
-
-  def this() = this(Identifiable.randomUID("PowerIterationClusteringWrapper"))
-
-  override def transform(dataset: Dataset[_]): DataFrame = {
-    val pic = new PowerIterationClustering()
-      .setK($(k))
-      .setInitMode($(initMode))
-      .setMaxIter($(maxIter))
-      .setSrcCol($(srcCol))
-      .setDstCol($(dstCol))
-    get(weightCol) match {
-      case Some(w) if w.nonEmpty => pic.setWeightCol(w)
-      case _ =>
-    }
-    pic.assignClusters(dataset)
-  }
-
-  override def transformSchema(schema: StructType): StructType =
-    new StructType()
-      .add(StructField("id", LongType, nullable = false))
-      .add(StructField("cluster", IntegerType, nullable = false))
-
-  override def copy(extra: ParamMap): PowerIterationClusteringWrapper = defaultCopy(extra)
-}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala
index 099e42ee2749..3ebfbad310aa 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.ml.fpm
 
 import org.apache.spark.annotation.Since
-import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.ml.util.Instrumentation.instrumented
@@ -168,29 +167,4 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends PrefixS
 
   @Since("2.4.0")
   override def copy(extra: ParamMap): PrefixSpan = defaultCopy(extra)
-
-}
-
-private[spark] class PrefixSpanWrapper(override val uid: String)
-  extends Transformer with PrefixSpanParams {
-
-  def this() = this(Identifiable.randomUID("prefixSpanWrapper"))
-
-  override def transformSchema(schema: StructType): StructType = {
-    new StructType()
-      .add("sequence", ArrayType(schema($(sequenceCol)).dataType), nullable = false)
-      .add("freq", LongType, nullable = false)
-  }
-
-  override def transform(dataset: Dataset[_]): DataFrame = {
-    val prefixSpan = new PrefixSpan(uid)
-    prefixSpan
-      .setMinSupport($(minSupport))
-      .setMaxPatternLength($(maxPatternLength))
-      .setMaxLocalProjDBSize($(maxLocalProjDBSize))
-      .setSequenceCol($(sequenceCol))
-      .findFrequentSequentialPatterns(dataset)
-  }
-
-  override def copy(extra: ParamMap): PrefixSpanWrapper = defaultCopy(extra)
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala
index 863ec640e7b7..2207c20049a1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/stat/ChiSquareTest.scala
@@ -18,17 +18,12 @@
 package org.apache.spark.ml.stat
 
 import org.apache.spark.annotation.Since
-import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
-import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol}
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.util.SchemaUtils
 import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
 import org.apache.spark.mllib.stat.test.{ChiSqTest => OldChiSqTest}
 import org.apache.spark.sql._
 import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types._
-
 
 /**
  * Chi-square hypothesis testing for categorical data.
@@ -105,40 +100,3 @@ object ChiSquareTest {
     }
   }
 }
-
-/**
- * [[ChiSquareTest]] is not an Estimator/Transformer and thus needs to be wrapped in a wrapper
- * to be compatible with Spark Connect.
- */
-private[spark] class ChiSquareTestWrapper(override val uid: String)
-  extends Transformer with HasFeaturesCol with HasLabelCol {
-
-  val flatten = new BooleanParam(this, "flatten",
-    "If false, the returned DataFrame contains only a single Row, otherwise, one row per feature.")
-
-  setDefault(flatten -> false)
-
-  def this() = this(Identifiable.randomUID("ChiSquareTestWrapper"))
-
-  override def transformSchema(schema: StructType): StructType = {
-    if ($(flatten)) {
-      new StructType()
-        .add("featureIndex", IntegerType, nullable = false)
-        .add("pValue", DoubleType, nullable = false)
-        .add("degreesOfFreedom", IntegerType, nullable = false)
-        .add("statistic", DoubleType, nullable = false)
-    } else {
-      new StructType()
-        .add("pValues", new VectorUDT, nullable = false)
-        .add("degreesOfFreedom", ArrayType(IntegerType, containsNull = false), nullable = false)
-        .add("statistics", new VectorUDT, nullable = false)
-    }
-  }
-
-  override def transform(dataset: Dataset[_]): DataFrame = {
-    ChiSquareTest.test(dataset.toDF(), $(featuresCol), $(labelCol), $(flatten))
-  }
-
-  override def copy(extra: ParamMap): ChiSquareTestWrapper = defaultCopy(extra)
-}
-
diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala
index 3c06abb23797..46fabc9808a7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/stat/Correlation.scala
@@ -20,11 +20,7 @@ package org.apache.spark.ml.stat
 import scala.jdk.CollectionConverters._
 
 import org.apache.spark.annotation.Since
-import org.apache.spark.ml._
 import org.apache.spark.ml.linalg.{SQLDataTypes, Vector}
-import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.HasFeaturesCol
-import org.apache.spark.ml.util.Identifiable
 import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
 import org.apache.spark.mllib.stat.{Statistics => OldStatistics}
 import org.apache.spark.sql.{DataFrame, Dataset, Row}
@@ -86,28 +82,3 @@ object Correlation {
     corr(dataset, column, "pearson")
   }
 }
-
-/**
- * [[Correlation]] is not an Estimator/Transformer and thus needs to be wrapped in a wrapper
- * to be compatible with Spark Connect.
- */
-private[spark] class CorrelationWrapper(override val uid: String)
-  extends Transformer with HasFeaturesCol {
-
-  val method = new Param[String](this, "method", "The correlation method to use")
-
-  setDefault(method -> "pearson")
-
-  def this() = this(Identifiable.randomUID("CorrelationWrapper"))
-
-  override def transformSchema(schema: StructType): StructType = {
-    val name = s"${$(method)}(${$(featuresCol)})"
-    StructType(Array(StructField(name, SQLDataTypes.MatrixType, nullable = false)))
-  }
-
-  override def transform(dataset: Dataset[_]): DataFrame = {
-    Correlation.corr(dataset, $(featuresCol), $(method))
-  }
-
-  override def copy(extra: ParamMap): CorrelationWrapper = defaultCopy(extra)
-}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala
index 2fc4a856b564..c11163949ab1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala
@@ -21,15 +21,11 @@ import scala.annotation.varargs
 
 import org.apache.spark.annotation.Since
 import org.apache.spark.api.java.function.Function
-import org.apache.spark.ml._
-import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.HasInputCol
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.stat.{Statistics => OldStatistics}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions.col
-import org.apache.spark.sql.types._
 
 /**
  * Conduct the two-sided Kolmogorov Smirnov (KS) test for data sampled from a
@@ -118,34 +114,3 @@ object KolmogorovSmirnovTest {
       testResult.pValue, testResult.statistic)))
   }
 }
-
-
-/**
- * [[KolmogorovSmirnovTest]] is not an Estimator/Transformer and thus needs to be wrapped
- * in a wrapper to be compatible with Spark Connect.
- */
-private[spark] class KolmogorovSmirnovTestWrapper(override val uid: String)
-  extends Transformer with HasInputCol {
-
-  val paramsArray = new DoubleArrayParam(this, "paramsArray",
-    "The parameters to be used for the theoretical distribution.")
-
-  val distName = new Param[String](this, "distName",
-    "The name of the theoretical distribution to test against")
-
-  setDefault(paramsArray -> Array.emptyDoubleArray)
-
-  def this() = this(Identifiable.randomUID("KolmogorovSmirnovTestWrapper"))
-
-  override def transformSchema(schema: StructType): StructType = {
-    new StructType()
-      .add("pValue", DoubleType, nullable = false)
-      .add("statistic", DoubleType, nullable = false)
-  }
-
-  override def transform(dataset: Dataset[_]): DataFrame = {
-    KolmogorovSmirnovTest.test(dataset, $(inputCol), $(distName), $(paramsArray).toIndexedSeq: _*)
-  }
-
-  override def copy(extra: ParamMap): KolmogorovSmirnovTestWrapper = defaultCopy(extra)
-}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ConnectHelper.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ConnectHelper.scala
index fb2e1a0c0b4e..834337692cab 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ConnectHelper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ConnectHelper.scala
@@ -17,8 +17,11 @@
 package org.apache.spark.ml.util
 
 import org.apache.spark.ml.Model
+import org.apache.spark.ml.clustering.PowerIterationClustering
 import org.apache.spark.ml.feature._
+import org.apache.spark.ml.fpm.PrefixSpan
 import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.stat._
 import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
 import org.apache.spark.sql.types.StructType
 
@@ -54,6 +57,64 @@ private[spark] class ConnectHelper(override val uid: String) extends Model[Conne
     StopWordsRemover.getDefaultOrUS.toString
   }
 
+  def chiSquareTest(
+      dataset: DataFrame,
+      featuresCol: String,
+      labelCol: String,
+      flatten: Boolean): DataFrame = {
+    ChiSquareTest.test(dataset, featuresCol, labelCol, flatten)
+  }
+
+  def correlation(
+      dataset: DataFrame,
+      column: String,
+      method: String): DataFrame = {
+    Correlation.corr(dataset, column, method)
+  }
+
+  def kolmogorovSmirnovTest(
+      dataset: DataFrame,
+      sampleCol: String,
+      distName: String,
+      params: Array[Double]): DataFrame = {
+    KolmogorovSmirnovTest.test(dataset, sampleCol, distName, params.toIndexedSeq: _*)
+  }
+
+  def powerIterationClusteringAssignClusters(
+      dataset: DataFrame,
+      k: Int,
+      maxIter: Int,
+      initMode: String,
+      srcCol: String,
+      dstCol: String,
+      weightCol: String): DataFrame = {
+    val pic = new PowerIterationClustering()
+      .setK(k)
+      .setMaxIter(maxIter)
+      .setInitMode(initMode)
+      .setSrcCol(srcCol)
+      .setDstCol(dstCol)
+    if (weightCol.nonEmpty) {
+      pic.setWeightCol(weightCol)
+    }
+    pic.assignClusters(dataset)
+  }
+
+  def prefixSpanFindFrequentSequentialPatterns(
+      dataset: DataFrame,
+      minSupport: Double,
+      maxPatternLength: Int,
+      maxLocalProjDBSize: Long,
+      sequenceCol: String): DataFrame = {
+    val prefixSpan = new PrefixSpan()
+      .setMinSupport(minSupport)
+      .setMaxPatternLength(maxPatternLength)
+      .setMaxLocalProjDBSize(maxLocalProjDBSize)
+      .setSequenceCol(sequenceCol)
+    prefixSpan.findFrequentSequentialPatterns(dataset)
+  }
+
+
   override def copy(extra: ParamMap): ConnectHelper = defaultCopy(extra)
 
   override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF()
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index dcd34ba365a5..a0710f435f3b 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -46,6 +46,7 @@ from pyspark.ml.util import (
     GeneralJavaMLWritable,
     HasTrainingSummary,
     try_remote_attribute_relation,
+    invoke_helper_relation,
 )
 from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams, JavaWrapper
 from pyspark.ml.common import inherit_doc
@@ -2155,16 +2156,16 @@ class PowerIterationClustering(
         assert self._java_obj is not None
 
         if is_remote():
-            from pyspark.ml.wrapper import JavaTransformer
-            from pyspark.ml.connect.serialize import serialize_ml_params
-
-            instance = JavaTransformer()
-            instance._java_obj = "org.apache.spark.ml.clustering.PowerIterationClusteringWrapper"
-            instance._serialized_ml_params = serialize_ml_params(  # type: ignore[attr-defined]
-                self,
-                dataset.sparkSession.client,  # type: ignore[arg-type,operator]
+            return invoke_helper_relation(
+                "powerIterationClusteringAssignClusters",
+                dataset,
+                self.getK(),
+                self.getMaxIter(),
+                self.getInitMode(),
+                self.getSrcCol(),
+                self.getDstCol(),
+                self.getWeightCol() if self.isDefined(self.weightCol) else "",
             )
-            return instance.transform(dataset)
 
         self._transfer_params_to_java()
 
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 3797b0f2b04c..6a4a9dc99875 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -4849,7 +4849,7 @@ class StringIndexerModel(
             model._java_obj = helper._call_java(
                 "stringIndexerModelFromLabels",
                 model.uid,
-                (list(labels), ArrayType(StringType(), False)),
+                (list(labels), ArrayType(StringType())),
             )
 
         else:
@@ -4891,7 +4891,7 @@ class StringIndexerModel(
                 model.uid,
                 (
                     [list(labels) for labels in arrayOfLabels],
-                    ArrayType(ArrayType(StringType(), False)),
+                    ArrayType(ArrayType(StringType())),
                 ),
             )
 
diff --git a/python/pyspark/ml/fpm.py b/python/pyspark/ml/fpm.py
index 0e46ecc45e93..64b7e5eae556 100644
--- a/python/pyspark/ml/fpm.py
+++ b/python/pyspark/ml/fpm.py
@@ -20,7 +20,12 @@ from typing import Any, Dict, Optional, TYPE_CHECKING
 
 from pyspark import keyword_only, since
 from pyspark.sql import DataFrame
-from pyspark.ml.util import JavaMLWritable, JavaMLReadable, try_remote_attribute_relation
+from pyspark.ml.util import (
+    JavaMLWritable,
+    JavaMLReadable,
+    try_remote_attribute_relation,
+    invoke_helper_relation,
+)
 from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams
 from pyspark.ml.param.shared import HasPredictionCol, Param, TypeConverters, Params
 
@@ -511,18 +516,14 @@ class PrefixSpan(JavaParams):
         assert self._java_obj is not None
 
         if is_remote():
-            from pyspark.ml.wrapper import JavaTransformer
-            from pyspark.ml.connect.serialize import serialize_ml_params
-
-            instance = JavaTransformer()
-            instance._java_obj = "org.apache.spark.ml.fpm.PrefixSpanWrapper"
-            # The helper object is just a JavaTransformer without any Param Mixin,
-            # copying the params by .copy() or directly assigning the _paramMap won't work
-            instance._serialized_ml_params = serialize_ml_params(  # type: ignore[attr-defined]
-                self,
-                dataset.sparkSession.client,  # type: ignore[arg-type,operator]
+            return invoke_helper_relation(
+                "prefixSpanFindFrequentSequentialPatterns",
+                dataset,
+                self.getMinSupport(),
+                self.getMaxPatternLength(),
+                self.getMaxLocalProjDBSize(),
+                self.getSequenceCol(),
             )
-            return instance.transform(dataset)
 
         self._transfer_params_to_java()
         jdf = self._java_obj.findFrequentSequentialPatterns(dataset._jdf)
diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py
index 16fa9f7edb80..b9d40d46fd09 100644
--- a/python/pyspark/ml/stat.py
+++ b/python/pyspark/ml/stat.py
@@ -22,9 +22,11 @@ from pyspark import since
 from pyspark.ml.common import _java2py, _py2java
 from pyspark.ml.linalg import Matrix, Vector
 from pyspark.ml.wrapper import JavaWrapper, _jvm
+from pyspark.ml.util import invoke_helper_relation
 from pyspark.sql.column import Column
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.functions import lit
+from pyspark.sql.types import ArrayType, DoubleType
 from pyspark.sql.utils import is_remote
 
 if TYPE_CHECKING:
@@ -104,17 +106,7 @@ class ChiSquareTest:
         4.0
         """
         if is_remote():
-            from pyspark.ml.wrapper import JavaTransformer
-            from pyspark.ml.connect.serialize import serialize_ml_params_values
-
-            instance = JavaTransformer()
-            instance._java_obj = "org.apache.spark.ml.stat.ChiSquareTestWrapper"
-            serialized_ml_params = serialize_ml_params_values(
-                {"featuresCol": featuresCol, "labelCol": labelCol, "flatten": flatten},
-                dataset.sparkSession.client,  # type: ignore[arg-type,operator]
-            )
-            instance._serialized_ml_params = serialized_ml_params  # type: ignore[attr-defined]
-            return instance.transform(dataset)
+            return invoke_helper_relation("chiSquareTest", dataset, featuresCol, labelCol, flatten)
 
         else:
             from pyspark.core.context import SparkContext
@@ -189,17 +181,7 @@ class Correlation:
                      [ 0.4       ,  0.9486... ,         NaN,  1.        ]])
         """
         if is_remote():
-            from pyspark.ml.wrapper import JavaTransformer
-            from pyspark.ml.connect.serialize import serialize_ml_params_values
-
-            instance = JavaTransformer()
-            instance._java_obj = "org.apache.spark.ml.stat.CorrelationWrapper"
-            serialized_ml_params = serialize_ml_params_values(
-                {"featuresCol": column, "method": method},
-                dataset.sparkSession.client,  # type: ignore[arg-type,operator]
-            )
-            instance._serialized_ml_params = serialized_ml_params  # type: ignore[attr-defined]
-            return instance.transform(dataset)
+            return invoke_helper_relation("correlation", dataset, column, method)
 
         else:
             from pyspark.core.context import SparkContext
@@ -273,17 +255,13 @@ class KolmogorovSmirnovTest:
         0.175
         """
         if is_remote():
-            from pyspark.ml.wrapper import JavaTransformer
-            from pyspark.ml.connect.serialize import serialize_ml_params_values
-
-            instance = JavaTransformer()
-            instance._java_obj = "org.apache.spark.ml.stat.KolmogorovSmirnovTestWrapper"
-            serialized_ml_params = serialize_ml_params_values(
-                {"inputCol": sampleCol, "distName": distName, "paramsArray": list(params)},
-                dataset.sparkSession.client,  # type: ignore[arg-type,operator]
+            return invoke_helper_relation(
+                "kolmogorovSmirnovTest",
+                dataset,
+                sampleCol,
+                distName,
+                ([float(p) for p in params], ArrayType(DoubleType())),
             )
-            instance._serialized_ml_params = serialized_ml_params  # type: ignore[attr-defined]
-            return instance.transform(dataset)
 
         else:
             from pyspark.core.context import SparkContext
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 666ebb0071c7..7b8ba57a1f8a 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -78,6 +78,37 @@ def try_remote_intermediate_result(f: FuncT) -> FuncT:
     return cast(FuncT, wrapped)
 
 
+def invoke_helper_relation(method: str, *args: Any) -> "ConnectDataFrame":
+    from pyspark.ml.wrapper import JavaWrapper
+
+    helper = JavaWrapper(java_obj=ML_CONNECT_HELPER_ID)
+    return invoke_remote_attribute_relation(helper, method, *args)
+
+
+def invoke_remote_attribute_relation(
+    instance: "JavaWrapper", method: str, *args: Any
+) -> "ConnectDataFrame":
+    import pyspark.sql.connect.proto as pb2
+    from pyspark.ml.connect.util import _extract_id_methods
+    from pyspark.ml.connect.serialize import serialize
+
+    # The attribute returns a dataframe, we need to wrap it
+    # in the AttributeRelation
+    from pyspark.ml.connect.proto import AttributeRelation
+    from pyspark.sql.connect.session import SparkSession
+    from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
+
+    session = SparkSession.getActiveSession()
+    assert session is not None
+
+    assert isinstance(instance._java_obj, str)
+
+    methods, obj_ref = _extract_id_methods(instance._java_obj)
+    methods.append(pb2.Fetch.Method(method=method, args=serialize(session.client, *args)))
+    plan = AttributeRelation(obj_ref, methods)
+    return ConnectDataFrame(plan, session)
+
+
 def try_remote_attribute_relation(f: FuncT) -> FuncT:
     """Mark the function/property that returns a Relation.
     Eg, model.summary.roc"""
@@ -85,27 +116,7 @@ def try_remote_attribute_relation(f: FuncT) -> FuncT:
     @functools.wraps(f)
     def wrapped(self: "JavaWrapper", *args: Any, **kwargs: Any) -> Any:
         if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
-            import pyspark.sql.connect.proto as pb2
-            from pyspark.ml.connect.util import _extract_id_methods
-            from pyspark.ml.connect.serialize import serialize
-
-            # The attribute returns a dataframe, we need to wrap it
-            # in the AttributeRelation
-            from pyspark.ml.connect.proto import AttributeRelation
-            from pyspark.sql.connect.session import SparkSession
-            from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
-
-            session = SparkSession.getActiveSession()
-            assert session is not None
-
-            assert isinstance(self._java_obj, str)
-
-            methods, obj_ref = _extract_id_methods(self._java_obj)
-            methods.append(
-                pb2.Fetch.Method(method=f.__name__, args=serialize(session.client, *args))
-            )
-            plan = AttributeRelation(obj_ref, methods)
-            return ConnectDataFrame(plan, session)
+            return invoke_remote_attribute_relation(self, f.__name__, *args)
         else:
             return f(self, *args, **kwargs)
 
@@ -161,16 +172,12 @@ def try_remote_transform_relation(f: FuncT) -> FuncT:
             session = dataset.sparkSession
             assert session is not None
 
-            if hasattr(self, "_serialized_ml_params"):
-                params = self._serialized_ml_params
-            else:
-                params = serialize_ml_params(self, session.client)  # type: ignore[arg-type]
-
             # Model is also a Transformer, so we much match Model first
             if isinstance(self, Model):
                 from pyspark.ml.connect.proto import TransformerRelation
 
                 assert isinstance(self._java_obj, str)
+                params = serialize_ml_params(self, session.client)
                 return ConnectDataFrame(
                     TransformerRelation(
                         child=dataset._plan, name=self._java_obj, ml_params=params, is_model=True
@@ -181,6 +188,7 @@ def try_remote_transform_relation(f: FuncT) -> FuncT:
                 from pyspark.ml.connect.proto import TransformerRelation
 
                 assert isinstance(self._java_obj, str)
+                params = serialize_ml_params(self, session.client)
                 return ConnectDataFrame(
                     TransformerRelation(
                         child=dataset._plan,
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
index b613b2202137..84a26d9e4962 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
@@ -661,7 +661,12 @@ private[ml] object MLUtils {
         "stringIndexerModelFromLabelsArray",
         "countVectorizerModelFromVocabulary",
         "stopWordsRemoverLoadDefaultStopWords",
-        "stopWordsRemoverGetDefaultOrUS")))
+        "stopWordsRemoverGetDefaultOrUS",
+        "chiSquareTest",
+        "correlation",
+        "kolmogorovSmirnovTest",
+        "powerIterationClusteringAssignClusters",
+        "prefixSpanFindFrequentSequentialPatterns")))
 
   private def validate(obj: Any, method: String): Unit = {
     assert(obj != null)
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala
index 2bbc0b258cad..df07dd42bc42 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/Serializer.scala
@@ -167,6 +167,8 @@ private[ml] object Serializer {
           case proto.Expression.Literal.LiteralTypeCase.ARRAY =>
             val array = literal.getArray
             array.getElementType.getKindCase match {
+              case proto.DataType.KindCase.DOUBLE =>
+                (parseDoubleArray(array), classOf[Array[Double]])
               case proto.DataType.KindCase.STRING =>
                 (parseStringArray(array), classOf[Array[String]])
               case proto.DataType.KindCase.ARRAY =>
@@ -191,6 +193,16 @@ private[ml] object Serializer {
     }
   }
 
+  private def parseDoubleArray(array: proto.Expression.Literal.Array): Array[Double] = {
+    val values = new Array[Double](array.getElementsCount)
+    var i = 0
+    while (i < array.getElementsCount) {
+      values(i) = array.getElements(i).getDouble
+      i += 1
+    }
+    values
+  }
+
   private def parseStringArray(array: proto.Expression.Literal.Array): Array[String] = {
     val values = new Array[String](array.getElementsCount)
     var i = 0


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to