zhengruifeng commented on code in PR #49596:
URL: https://github.com/apache/spark/pull/49596#discussion_r1926275016
##########
python/pyspark/ml/connect/readwrite.py:
##########
@@ -37,52 +38,99 @@ def sc(self) -> "SparkContext":
raise RuntimeError("Accessing SparkContext is not supported on
Connect")
def save(self, path: str) -> None:
- from pyspark.ml.wrapper import JavaModel, JavaEstimator,
JavaTransformer
- from pyspark.ml.evaluation import JavaEvaluator
from pyspark.sql.connect.session import SparkSession
session = SparkSession.getActiveSession()
assert session is not None
+ RemoteMLWriter.saveInstance(
+ self._instance,
+ path,
+ session,
+ self.shouldOverwrite,
+ self.optionMap,
+ )
+
+ @staticmethod
+ def saveInstance(
+ instance: "JavaMLWritable",
+ path: str,
+ session: "SparkSession",
+ shouldOverwrite: bool = False,
+ optionMap: Dict[str, Any] = {},
+ ) -> None:
+ from pyspark.ml.wrapper import JavaModel, JavaEstimator,
JavaTransformer
+ from pyspark.ml.evaluation import JavaEvaluator
+ from pyspark.ml.pipeline import Pipeline, PipelineModel
+
# Spark Connect ML is built on scala Spark.ML, that means we're only
# supporting JavaModel or JavaEstimator or JavaEvaluator
- if isinstance(self._instance, JavaModel):
- model = cast("JavaModel", self._instance)
+ if isinstance(instance, JavaModel):
+ model = cast("JavaModel", instance)
params = serialize_ml_params(model, session.client)
assert isinstance(model._java_obj, str)
writer = pb2.MlCommand.Write(
obj_ref=pb2.ObjectRef(id=model._java_obj),
params=params,
path=path,
- should_overwrite=self.shouldOverwrite,
- options=self.optionMap,
+ should_overwrite=shouldOverwrite,
+ options=optionMap,
)
- else:
+ command = pb2.Command()
+ command.ml_command.write.CopyFrom(writer)
+ session.client.execute_command(command)
+
+ elif isinstance(instance, (JavaEstimator, JavaTransformer,
JavaEvaluator)):
operator: Union[JavaEstimator, JavaTransformer, JavaEvaluator]
- if isinstance(self._instance, JavaEstimator):
+ if isinstance(instance, JavaEstimator):
ml_type = pb2.MlOperator.ESTIMATOR
- operator = cast("JavaEstimator", self._instance)
- elif isinstance(self._instance, JavaEvaluator):
+ operator = cast("JavaEstimator", instance)
+ elif isinstance(instance, JavaEvaluator):
ml_type = pb2.MlOperator.EVALUATOR
- operator = cast("JavaEvaluator", self._instance)
- elif isinstance(self._instance, JavaTransformer):
- ml_type = pb2.MlOperator.TRANSFORMER
- operator = cast("JavaTransformer", self._instance)
+ operator = cast("JavaEvaluator", instance)
else:
- raise NotImplementedError(f"Unsupported writing for
{self._instance}")
+ ml_type = pb2.MlOperator.TRANSFORMER
+ operator = cast("JavaTransformer", instance)
params = serialize_ml_params(operator, session.client)
assert isinstance(operator._java_obj, str)
writer = pb2.MlCommand.Write(
operator=pb2.MlOperator(name=operator._java_obj,
uid=operator.uid, type=ml_type),
params=params,
path=path,
- should_overwrite=self.shouldOverwrite,
- options=self.optionMap,
+ should_overwrite=shouldOverwrite,
+ options=optionMap,
)
- command = pb2.Command()
- command.ml_command.write.CopyFrom(writer)
- session.client.execute_command(command)
+ command = pb2.Command()
+ command.ml_command.write.CopyFrom(writer)
+ session.client.execute_command(command)
+
+ elif isinstance(instance, (Pipeline, PipelineModel)):
+ from pyspark.ml.pipeline import PipelineSharedReadWrite
+
+ if shouldOverwrite:
Review Comment:
Let me skip support for `shouldOverwrite` for now; we will need to find a
better way to handle it.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]