wbo4958 commented on code in PR #48791:
URL: https://github.com/apache/spark/pull/48791#discussion_r1856319287


##########
python/pyspark/ml/remote/readwrite.py:
##########
@@ -0,0 +1,100 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import cast, Type, TYPE_CHECKING
+
+import pyspark.sql.connect.proto as pb2
+from pyspark.ml.remote.serialize import serialize_ml_params, deserialize, 
deserialize_param
+from pyspark.ml.util import MLWriter, MLReader, RL
+from pyspark.ml.wrapper import JavaWrapper
+
+if TYPE_CHECKING:
+    from pyspark.ml.util import JavaMLReadable, JavaMLWritable
+    from pyspark.core.context import SparkContext
+
+
+class RemoteMLWriter(MLWriter):
+    def __init__(self, instance: "JavaMLWritable") -> None:
+        super().__init__()
+        self._instance = instance
+
+    @property
+    def sc(self) -> "SparkContext":
+        raise RuntimeError("Accessing SparkContext is not supported on 
Connect")
+
+    def save(self, path: str) -> None:
+        from pyspark.ml.wrapper import JavaModel
+
+        if isinstance(self._instance, JavaModel):
+            from pyspark.sql.connect.session import SparkSession
+
+            session = SparkSession.getActiveSession()
+            assert session is not None
+            instance = cast("JavaModel", self._instance)
+            params = serialize_ml_params(instance, session.client)
+
+            assert isinstance(instance._java_obj, str)
+            writer = pb2.MlCommand.Writer(
+                model_ref=pb2.ModelRef(id=instance._java_obj),
+                params=params,
+                path=path,
+                should_overwrite=self.shouldOverwrite,
+                options=self.optionMap,
+            )
+            req = session.client._execute_plan_request_with_metadata()
+            req.plan.ml_command.write.CopyFrom(writer)
+            session.client.execute_ml(req)
+
+
+class RemoteMLReader(MLReader[RL]):
+    def __init__(self, clazz: Type["JavaMLReadable[RL]"]) -> None:
+        super().__init__()
+        self._clazz = clazz

Review Comment:
   Yes, true. Both the estimator and the model extend JavaMLReadable, so
it's generalizable.



##########
python/pyspark/ml/remote/proto.py:
##########
@@ -0,0 +1,86 @@
+#

Review Comment:
   I could also move the existing "remote" to "connect" if necessary.



##########
sql/connect/common/src/main/protobuf/spark/connect/base.proto:
##########
@@ -384,6 +386,9 @@ message ExecutePlanResponse {
     // Response for command that checkpoints a DataFrame.
     CheckpointCommandResult checkpoint_command_result = 19;
 
+    // ML command response
+    MlCommandResponse ml_command_result = 100;

Review Comment:
   It's already there.



##########
mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator:
##########
@@ -0,0 +1,18 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+org.apache.spark.ml.classification.LogisticRegression

Review Comment:
   Sounds good. Added the comment here.



##########
python/pyspark/ml/remote/serialize.py:
##########
@@ -0,0 +1,120 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import Any, List, TYPE_CHECKING, Mapping, Optional
+
+import pyspark.sql.connect.proto as pb2
+from pyspark.ml.linalg import (
+    Vectors,
+    Matrices,
+    DenseVector,
+    SparseVector,
+    DenseMatrix,
+    SparseMatrix,
+)
+from pyspark.sql.connect.dataframe import DataFrame as RemoteDataFrame
+from pyspark.sql.connect.expressions import LiteralExpression
+
+if TYPE_CHECKING:
+    from pyspark.sql.connect.client import SparkConnectClient
+    from pyspark.ml.param import Params
+
+
+def serialize_param(value: Any, client: "SparkConnectClient") -> pb2.Param:
+    if isinstance(value, DenseVector):
+        return 
pb2.Param(vector=pb2.Vector(dense=pb2.Vector.Dense(value=value.values.tolist())))
+    elif isinstance(value, SparseVector):
+        return pb2.Param(
+            vector=pb2.Vector(
+                sparse=pb2.Vector.Sparse(
+                    size=value.size, index=value.indices.tolist(), 
value=value.values.tolist()
+                )
+            )
+        )
+    elif isinstance(value, DenseMatrix):
+        return pb2.Param(
+            matrix=pb2.Matrix(
+                dense=pb2.Matrix.Dense(
+                    num_rows=value.numRows, num_cols=value.numCols, 
value=value.values.tolist()
+                )
+            )
+        )
+    elif isinstance(value, SparseMatrix):
+        return pb2.Param(
+            matrix=pb2.Matrix(
+                sparse=pb2.Matrix.Sparse(
+                    num_rows=value.numRows,
+                    num_cols=value.numCols,
+                    colptr=value.colPtrs.tolist(),
+                    row_index=value.rowIndices.tolist(),
+                    value=value.values.tolist(),
+                )
+            )
+        )
+    else:
+        literal = LiteralExpression._from_value(value).to_plan(client).literal
+        return pb2.Param(literal=literal)
+
+
+def serialize(client: "SparkConnectClient", *args: Any) -> List[Any]:
+    result = []
+    for arg in args:
+        if isinstance(arg, RemoteDataFrame):
+            
result.append(pb2.FetchModelAttr.Args(input=arg._plan.plan(client)))
+        else:
+            result.append(pb2.FetchModelAttr.Args(param=serialize_param(arg, 
client)))
+    return result
+
+
+def deserialize_param(param: pb2.Param) -> Any:
+    if param.HasField("literal"):
+        return LiteralExpression._to_value(param.literal)
+    if param.HasField("vector"):
+        vector = param.vector
+        # TODO support sparse vector
+        if vector.HasField("dense"):
+            return Vectors.dense(vector.dense.value)
+        raise ValueError("TODO, support sparse vector")
+
+    if param.HasField("matrix"):
+        matrix = param.matrix
+        # TODO support sparse matrix
+        if matrix.HasField("dense") and not matrix.dense.is_transposed:
+            return Matrices.dense(
+                matrix.dense.num_rows,
+                matrix.dense.num_cols,
+                matrix.dense.value,
+            )
+        raise ValueError("TODO, support sparse matrix")

Review Comment:
   Hmm, I will support it in this PR.



##########
python/pyspark/ml/tests/test_training_summary.py:
##########
@@ -122,94 +119,6 @@ def test_glr_summary(self):
         sameSummary = model.evaluate(df)
         self.assertAlmostEqual(sameSummary.deviance, s.deviance)
 
-    def test_binary_logistic_regression_summary(self):

Review Comment:
   I've moved the tests into test_classification.py, which holds the base class
for both classic Spark and Spark Connect. See
https://github.com/apache/spark/pull/48791/files#diff-9bb8f507634f33717f8758848a6369d9f57a0aeb6629b16648a38fbd15352892R39
and
https://github.com/apache/spark/pull/48791/files#diff-a4da55d9f9f7dd55641f346cffe7ee87432f5ca06da40a75dddca514e6c6a648R184



##########
python/pyspark/ml/remote/util.py:
##########
@@ -0,0 +1,259 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import functools
+import os
+from typing import Any, cast, TypeVar, Callable, TYPE_CHECKING, Type
+
+import pyspark.sql.connect.proto as pb2
+from pyspark.ml.remote.serialize import serialize_ml_params, serialize, 
deserialize
+from pyspark.sql import is_remote
+from pyspark.sql.connect.dataframe import DataFrame as RemoteDataFrame
+
+if TYPE_CHECKING:
+    from pyspark.ml.wrapper import JavaWrapper, JavaEstimator
+    from pyspark.ml.util import JavaMLReadable, JavaMLWritable
+
+FuncT = TypeVar("FuncT", bound=Callable[..., Any])
+
+
+def try_remote_intermediate_result(f: FuncT) -> FuncT:
+    """Mark the function/property that returns the intermediate result of the 
remote call.
+    Eg, model.summary"""
+
+    @functools.wraps(f)
+    def wrapped(self: "JavaWrapper") -> Any:
+        if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
+            return f"{self._java_obj}.{f.__name__}"
+        else:
+            return f(self)
+
+    return cast(FuncT, wrapped)
+
+
+def try_remote_attribute_relation(f: FuncT) -> FuncT:
+    """Mark the function/property that returns a Relation.
+    Eg, model.summary.roc"""
+
+    @functools.wraps(f)
+    def wrapped(self: "JavaWrapper", *args: Any, **kwargs: Any) -> Any:
+        if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
+            # The attribute returns a dataframe, we need to wrap it
+            # in the _ModelAttributeRelationPlan
+            from pyspark.ml.remote.proto import _ModelAttributeRelationPlan
+            from pyspark.sql.connect.session import SparkSession
+
+            assert isinstance(self._java_obj, str)

Review Comment:
   This assert is mainly here to satisfy Python type checking. We're
reusing `self._java_obj` in this PR, but `self._java_obj` is a "JavaObject" in
classic Spark.



##########
python/pyspark/ml/remote/serialize.py:
##########
@@ -0,0 +1,120 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from typing import Any, List, TYPE_CHECKING, Mapping, Optional
+
+import pyspark.sql.connect.proto as pb2
+from pyspark.ml.linalg import (
+    Vectors,
+    Matrices,
+    DenseVector,
+    SparseVector,
+    DenseMatrix,
+    SparseMatrix,
+)
+from pyspark.sql.connect.dataframe import DataFrame as RemoteDataFrame
+from pyspark.sql.connect.expressions import LiteralExpression
+
+if TYPE_CHECKING:
+    from pyspark.sql.connect.client import SparkConnectClient
+    from pyspark.ml.param import Params
+
+
+def serialize_param(value: Any, client: "SparkConnectClient") -> pb2.Param:
+    if isinstance(value, DenseVector):
+        return 
pb2.Param(vector=pb2.Vector(dense=pb2.Vector.Dense(value=value.values.tolist())))
+    elif isinstance(value, SparseVector):
+        return pb2.Param(
+            vector=pb2.Vector(
+                sparse=pb2.Vector.Sparse(
+                    size=value.size, index=value.indices.tolist(), 
value=value.values.tolist()
+                )
+            )
+        )
+    elif isinstance(value, DenseMatrix):
+        return pb2.Param(
+            matrix=pb2.Matrix(
+                dense=pb2.Matrix.Dense(
+                    num_rows=value.numRows, num_cols=value.numCols, 
value=value.values.tolist()
+                )
+            )
+        )
+    elif isinstance(value, SparseMatrix):
+        return pb2.Param(
+            matrix=pb2.Matrix(
+                sparse=pb2.Matrix.Sparse(
+                    num_rows=value.numRows,
+                    num_cols=value.numCols,
+                    colptr=value.colPtrs.tolist(),
+                    row_index=value.rowIndices.tolist(),
+                    value=value.values.tolist(),
+                )
+            )
+        )
+    else:
+        literal = LiteralExpression._from_value(value).to_plan(client).literal
+        return pb2.Param(literal=literal)
+
+
+def serialize(client: "SparkConnectClient", *args: Any) -> List[Any]:
+    result = []
+    for arg in args:
+        if isinstance(arg, RemoteDataFrame):
+            
result.append(pb2.FetchModelAttr.Args(input=arg._plan.plan(client)))
+        else:
+            result.append(pb2.FetchModelAttr.Args(param=serialize_param(arg, 
client)))
+    return result
+
+
+def deserialize_param(param: pb2.Param) -> Any:
+    if param.HasField("literal"):
+        return LiteralExpression._to_value(param.literal)
+    if param.HasField("vector"):
+        vector = param.vector
+        # TODO support sparse vector
+        if vector.HasField("dense"):
+            return Vectors.dense(vector.dense.value)
+        raise ValueError("TODO, support sparse vector")
+
+    if param.HasField("matrix"):
+        matrix = param.matrix
+        # TODO support sparse matrix
+        if matrix.HasField("dense") and not matrix.dense.is_transposed:
+            return Matrices.dense(
+                matrix.dense.num_rows,
+                matrix.dense.num_cols,
+                matrix.dense.value,
+            )
+        raise ValueError("TODO, support sparse matrix")
+
+    raise ValueError("Unsupported param type")
+
+
+def deserialize(ml_command_result: Optional[pb2.MlCommandResponse]) -> Any:
+    assert ml_command_result is not None
+    if ml_command_result.HasField("operator_info"):
+        return ml_command_result.operator_info
+
+    if ml_command_result.HasField("param"):
+        return deserialize_param(ml_command_result.param)
+    raise ValueError()

Review Comment:
   Sounds good. Done



##########
python/pyspark/sql/connect/client/core.py:
##########
@@ -1143,6 +1148,26 @@ def execute_command_as_iterator(
                     },
                 )
 
+    def execute_ml(self, req: pb2.ExecutePlanRequest) -> 
Optional[pb2.MlCommandResponse]:
+        """
+        Execute the ML command request and return ML response result
+        Parameters
+        ----------
+        req : pb2.ExecutePlanRequest
+            Proto representation of the plan.
+        """
+        logger.info("Execute ML")
+        try:
+            for attempt in self._retrying():
+                with attempt:
+                    for b in self._stub.ExecutePlan(req, 
metadata=self._builder.metadata()):
+                        if b.HasField("ml_command_result"):
+                            return b.ml_command_result

Review Comment:
   Yeah, the new commit has moved the ml_command into the existing command.
That's really cool.



##########
python/pyspark/ml/remote/util.py:
##########
@@ -0,0 +1,259 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import functools
+import os
+from typing import Any, cast, TypeVar, Callable, TYPE_CHECKING, Type
+
+import pyspark.sql.connect.proto as pb2
+from pyspark.ml.remote.serialize import serialize_ml_params, serialize, 
deserialize
+from pyspark.sql import is_remote
+from pyspark.sql.connect.dataframe import DataFrame as RemoteDataFrame
+
+if TYPE_CHECKING:
+    from pyspark.ml.wrapper import JavaWrapper, JavaEstimator
+    from pyspark.ml.util import JavaMLReadable, JavaMLWritable
+
+FuncT = TypeVar("FuncT", bound=Callable[..., Any])
+
+
+def try_remote_intermediate_result(f: FuncT) -> FuncT:
+    """Mark the function/property that returns the intermediate result of the 
remote call.
+    Eg, model.summary"""
+
+    @functools.wraps(f)
+    def wrapped(self: "JavaWrapper") -> Any:
+        if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
+            return f"{self._java_obj}.{f.__name__}"
+        else:
+            return f(self)
+
+    return cast(FuncT, wrapped)
+
+
+def try_remote_attribute_relation(f: FuncT) -> FuncT:
+    """Mark the function/property that returns a Relation.
+    Eg, model.summary.roc"""
+
+    @functools.wraps(f)
+    def wrapped(self: "JavaWrapper", *args: Any, **kwargs: Any) -> Any:
+        if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
+            # The attribute returns a dataframe, we need to wrap it
+            # in the _ModelAttributeRelationPlan
+            from pyspark.ml.remote.proto import _ModelAttributeRelationPlan
+            from pyspark.sql.connect.session import SparkSession
+
+            assert isinstance(self._java_obj, str)
+            plan = _ModelAttributeRelationPlan(self._java_obj, f.__name__)
+            session = SparkSession.getActiveSession()
+            assert session is not None
+            return RemoteDataFrame(plan, session)
+        else:
+            return f(self, *args, **kwargs)
+
+    return cast(FuncT, wrapped)
+
+
+def try_remote_fit(f: FuncT) -> FuncT:
+    """Mark the function that fits a model."""
+
+    @functools.wraps(f)
+    def wrapped(self: "JavaEstimator", dataset: RemoteDataFrame) -> Any:
+        if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
+            client = dataset.sparkSession.client
+            input = dataset._plan.plan(client)
+            assert isinstance(self._java_obj, str)

Review Comment:
   Yeah, mainly for the typing issue.



##########
python/pyspark/ml/tests/test_persistence.py:
##########
@@ -153,29 +153,6 @@ def test_linear_regression_pmml_basic(self):
         self.assertIn("Apache Spark", pmml_text)
         self.assertIn("PMML", pmml_text)
 
-    def test_logistic_regression(self):

Review Comment:
   It's in test_classification.py. 
https://github.com/apache/spark/pull/48791/files#diff-a4da55d9f9f7dd55641f346cffe7ee87432f5ca06da40a75dddca514e6c6a648R236



##########
sql/connect/common/src/main/protobuf/spark/connect/base.proto:
##########
@@ -38,6 +39,7 @@ message Plan {
   oneof op_type {
     Relation root = 1;
     Command command = 2;
+    MlCommand ml_command = 3;
   }
 }

Review Comment:
   Really appreciate the valuable suggestions. Done.



##########
sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala:
##########
@@ -311,6 +313,17 @@ private[connect] class ExecuteThreadRunner(executeHolder: 
ExecuteHolder) extends
     planner.process(command = command, responseObserver = responseObserver)
   }
 
+  private def handleMLCommand(request: proto.ExecutePlanRequest): Unit = {

Review Comment:
   Yes, I really appreciate the comments. After putting the ML commands into
the existing commands, life has become easier.



##########
sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala:
##########
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.ml
+
+import java.util.UUID
+import java.util.concurrent.ConcurrentHashMap
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.Model
+
+// TODO need to support persistence for model if memory is tight

Review Comment:
   Sounds good, will have one.



##########
sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala:
##########
@@ -0,0 +1,234 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.ml
+
+import scala.jdk.CollectionConverters.CollectionHasAsScala
+
+import org.apache.spark.connect.proto
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.Model
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.util.MLWritable
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.connect.common.LiteralValueProtoConverter
+import org.apache.spark.sql.connect.ml.MLUtils.loadModel
+import org.apache.spark.sql.connect.ml.Serializer.deserializeMethodArguments
+import org.apache.spark.sql.connect.service.SessionHolder
+
+private class ModelAttributeHelper(
+    val sessionHolder: SessionHolder,
+    val objIdentifier: String,
+    val method: Option[String],
+    val argValues: Array[Object] = Array.empty,
+    val argClasses: Array[Class[_]] = Array.empty) {
+
+  val methodChain = method.map(n => 
s"$objIdentifier.$n").getOrElse(objIdentifier)
+  private val methodChains = methodChain.split("\\.")
+  private val modelId = methodChains.head
+
+  private lazy val model = sessionHolder.mlCache.get(modelId)
+  private lazy val methods = methodChains.slice(1, methodChains.length)
+
+  def getAttribute: Any = {
+    assert(methods.length >= 1)
+    if (argValues.length == 0) {
+      methods.foldLeft(model.asInstanceOf[Object]) { (obj, attribute) =>
+        MLUtils.invokeMethodAllowed(obj, attribute)
+      }
+    } else {
+      val lastMethod = methods.last
+      if (methods.length == 1) {
+        MLUtils.invokeMethodAllowed(model.asInstanceOf[Object], lastMethod, 
argValues, argClasses)
+      } else {
+        val prevMethods = methods.slice(0, methods.length - 1)
+        val finalObj = prevMethods.foldLeft(model.asInstanceOf[Object]) { 
(obj, attribute) =>
+          MLUtils.invokeMethodAllowed(obj, attribute)
+        }
+        MLUtils.invokeMethodAllowed(finalObj, lastMethod, argValues, 
argClasses)
+      }
+    }
+  }
+
+  def transform(relation: proto.MlRelation.Transform): DataFrame = {
+    // Create a copied model to avoid concurrently modify model params.
+    val copiedModel = model.copy(ParamMap.empty).asInstanceOf[Model[_]]
+    MLUtils.setInstanceParams(copiedModel, relation.getParams)
+    val inputDF = MLUtils.parseRelationProto(relation.getInput, sessionHolder)
+    copiedModel.transform(inputDF)
+  }
+}
+
+private object ModelAttributeHelper {
+  def apply(
+      sessionHolder: SessionHolder,
+      modelId: String,
+      method: Option[String] = None,
+      args: Array[proto.FetchModelAttr.Args] = Array.empty): 
ModelAttributeHelper = {
+    val tmp = deserializeMethodArguments(args, sessionHolder)
+    val argValues = tmp.map(_._1)
+    val argClasses = tmp.map(_._2)
+    new ModelAttributeHelper(sessionHolder, modelId, method, argValues, 
argClasses)
+  }
+}
+
+object MLHandler extends Logging {

Review Comment:
   Added some documentation like "MLHandler is a utility to group all ML 
operations". 
   
   > Why can it be an object
   
   I don't have a strong reason why it must be an object. Any ideas about it?



##########
sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala:
##########
@@ -0,0 +1,277 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.ml
+
+import java.util.ServiceLoader
+
+import scala.collection.immutable.HashSet
+import scala.jdk.CollectionConverters.{IterableHasAsScala, MapHasAsScala}
+
+import org.apache.commons.lang3.reflect.MethodUtils.{invokeMethod, 
invokeStaticMethod}
+
+import org.apache.spark.connect.proto
+import org.apache.spark.ml.{Estimator, Model, Transformer}
+import org.apache.spark.ml.linalg.{Matrices, Matrix, Vector, Vectors}
+import org.apache.spark.ml.param.Params
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.connect.common.LiteralValueProtoConverter
+import org.apache.spark.sql.connect.planner.SparkConnectPlanner
+import org.apache.spark.sql.connect.service.SessionHolder
+import org.apache.spark.util.Utils
+
+object MLUtils {
+
+  private lazy val estimators: Map[String, Class[_]] = {
+    val loader = Utils.getContextOrSparkClassLoader
+    val serviceLoader = ServiceLoader.load(classOf[Estimator[_]], loader)
+    val providers = serviceLoader.asScala.toList
+    providers.map(est => est.getClass.getName -> est.getClass).toMap
+  }
+
+  private lazy val transformers: Map[String, Class[_]] = {
+    val loader = Utils.getContextOrSparkClassLoader
+    val serviceLoader = ServiceLoader.load(classOf[Transformer], loader)
+    val providers = serviceLoader.asScala.toList
+    providers.map(est => est.getClass.getName -> est.getClass).toMap
+  }
+
+  def deserializeVector(vector: proto.Vector): Vector = {
+    if (vector.hasDense) {
+      val values = vector.getDense.getValueList.asScala.map(_.toDouble).toArray
+      Vectors.dense(values)
+    } else {
+      val size = vector.getSparse.getSize
+      val indices = vector.getSparse.getIndexList.asScala.map(_.toInt).toArray
+      val values = 
vector.getSparse.getValueList.asScala.map(_.toDouble).toArray
+      Vectors.sparse(size, indices, values)
+    }
+  }
+
+  def deserializeMatrix(matrix: proto.Matrix): Matrix = {
+    if (matrix.hasDense) {
+      val values = matrix.getDense.getValueList.asScala.map(_.toDouble).toArray
+      Matrices.dense(matrix.getDense.getNumRows, matrix.getDense.getNumCols, 
values)
+    } else {
+      val sparse = matrix.getSparse
+      val colPtrs = sparse.getColptrList.asScala.map(_.toInt).toArray
+      val rowIndices = sparse.getRowIndexList.asScala.map(_.toInt).toArray
+      val values = sparse.getValueList.asScala.map(_.toDouble).toArray
+      Matrices.sparse(sparse.getNumRows, sparse.getNumCols, colPtrs, 
rowIndices, values)
+    }
+  }
+
+  def setInstanceParams(instance: Params, params: proto.MlParams): Unit = {
+    params.getParamsMap.asScala.foreach { case (name, paramProto) =>
+      val p = instance.getParam(name)
+      val value = if (paramProto.hasLiteral) {
+        convertParamValue(
+          p.paramValueClassTag.runtimeClass,
+          LiteralValueProtoConverter.toCatalystValue(paramProto.getLiteral))
+      } else if (paramProto.hasVector) {
+        deserializeVector(paramProto.getVector)
+      } else if (paramProto.hasMatrix) {
+        deserializeMatrix(paramProto.getMatrix)
+      } else {
+        throw new RuntimeException("Unsupported parameter type")
+      }
+      instance.set(p, value)
+    }
+  }
+
+  private def convertArray(paramType: Class[_], array: Array[_]): Array[_] = {
+    if (paramType == classOf[Byte]) {
+      array.map(_.asInstanceOf[Byte])
+    } else if (paramType == classOf[Short]) {
+      array.map(_.asInstanceOf[Short])
+    } else if (paramType == classOf[Int]) {
+      array.map(_.asInstanceOf[Int])
+    } else if (paramType == classOf[Long]) {
+      array.map(_.asInstanceOf[Long])
+    } else if (paramType == classOf[Float]) {
+      array.map(_.asInstanceOf[Float])
+    } else if (paramType == classOf[Double]) {
+      array.map(_.asInstanceOf[Double])
+    } else if (paramType == classOf[String]) {
+      array.map(_.asInstanceOf[String])
+    } else {
+      array
+    }
+  }
+
+  private def convertParamValue(paramType: Class[_], value: Any): Any = {
+    // Some cases the param type might be mismatched with the value type.
+    // Because in python side we only have int / float type for numeric params.
+    // e.g.:
+    // param type is Int but client sends a Long type.
+    // param type is Long but client sends a Int type.
+    // param type is Float but client sends a Double type.
+    // param type is Array[Int] but client sends a Array[Long] type.
+    // param type is Array[Float] but client sends a Array[Double] type.
+    // param type is Array[Array[Int]] but client sends a Array[Array[Long]] 
type.
+    // param type is Array[Array[Float]] but client sends a 
Array[Array[Double]] type.
+    if (paramType == classOf[Byte]) {
+      value.asInstanceOf[java.lang.Number].byteValue()
+    } else if (paramType == classOf[Short]) {
+      value.asInstanceOf[java.lang.Number].shortValue()
+    } else if (paramType == classOf[Int]) {
+      value.asInstanceOf[java.lang.Number].intValue()
+    } else if (paramType == classOf[Long]) {
+      value.asInstanceOf[java.lang.Number].longValue()
+    } else if (paramType == classOf[Float]) {
+      value.asInstanceOf[java.lang.Number].floatValue()
+    } else if (paramType == classOf[Double]) {
+      value.asInstanceOf[java.lang.Number].doubleValue()
+    } else if (paramType.isArray) {
+      val compType = paramType.getComponentType
+      val array = value.asInstanceOf[Array[_]].map { e =>
+        convertParamValue(compType, e)
+      }
+      convertArray(compType, array)
+    } else {
+      value
+    }
+  }
+
+  def parseRelationProto(relation: proto.Relation, sessionHolder: 
SessionHolder): DataFrame = {
+    val planner = new SparkConnectPlanner(sessionHolder)
+    val plan = planner.transformRelation(relation)
+    Dataset.ofRows(sessionHolder.session, plan)
+  }
+
+  /**
+   * Get the Estimator instance according to the fit command
+   *
+   * @param fit
+   *   command
+   * @return
+   *   an Estimator
+   */
+  def getEstimator(fit: proto.MlCommand.Fit): Estimator[_] = {
+    // TODO support plugin
+    // Get the estimator according to the fit command
+    val name = fit.getEstimator.getName
+    if (estimators.isEmpty || !estimators.contains(name)) {
+      throw new RuntimeException(s"Failed to find estimator: $name")
+    }
+    val uid = fit.getEstimator.getUid
+    val estimator: Estimator[_] = estimators(name)
+      .getConstructor(classOf[String])
+      .newInstance(uid)
+      .asInstanceOf[Estimator[_]]
+
+    // Set parameters for the estimator
+    val params = fit.getParams
+    MLUtils.setInstanceParams(estimator, params)
+    estimator
+  }
+
+  def loadModel(className: String, path: String): Model[_] = {
+    // scalastyle:off classforname
+    val clazz = Class.forName(className)
+    // scalastyle:on classforname
+    val model = invokeStaticMethod(clazz, "load", path)
+    model.asInstanceOf[Model[_]]
+  }
+
+  /**
+   * Get the transformer instance according to the transform proto
+   *
+   * @param transformProto
+   *   transform proto
+   * @return
+   *   a Transformer
+   */
+  def getTransformer(transformProto: proto.MlRelation.Transform): Transformer 
= {
+    // Get the transformer name
+    val name = transformProto.getTransformer.getName
+    if (transformers.isEmpty || !transformers.contains(name)) {
+      throw new RuntimeException(s"Failed to find transformer: $name")
+    }
+    val uid = transformProto.getTransformer.getUid
+    val transformer = transformers(name)
+      .getConstructor(classOf[String])
+      .newInstance(uid)
+      .asInstanceOf[Transformer]
+
+    val params = transformProto.getParams
+    MLUtils.setInstanceParams(transformer, params)
+    transformer
+  }
+
+  private lazy val ALLOWED_ATTRIBUTES = HashSet(

Review Comment:
   Done.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to