zhengruifeng commented on code in PR #49596:
URL: https://github.com/apache/spark/pull/49596#discussion_r1924918581
##########
python/pyspark/ml/connect/readwrite.py:
##########
@@ -37,52 +38,99 @@ def sc(self) -> "SparkContext":
raise RuntimeError("Accessing SparkContext is not supported on
Connect")
def save(self, path: str) -> None:
- from pyspark.ml.wrapper import JavaModel, JavaEstimator,
JavaTransformer
- from pyspark.ml.evaluation import JavaEvaluator
from pyspark.sql.connect.session import SparkSession
session = SparkSession.getActiveSession()
assert session is not None
+ RemoteMLWriter.saveInstance(
+ self._instance,
+ path,
+ session,
+ self.shouldOverwrite,
+ self.optionMap,
+ )
+
+ @staticmethod
+ def saveInstance(
+ instance: "JavaMLWritable",
+ path: str,
+ session: "SparkSession",
+ shouldOverwrite: bool = False,
+ optionMap: Dict[str, Any] = {},
+ ) -> None:
+ from pyspark.ml.wrapper import JavaModel, JavaEstimator,
JavaTransformer
+ from pyspark.ml.evaluation import JavaEvaluator
+ from pyspark.ml.pipeline import Pipeline, PipelineModel
+
# Spark Connect ML is built on scala Spark.ML, that means we're only
# supporting JavaModel or JavaEstimator or JavaEvaluator
- if isinstance(self._instance, JavaModel):
- model = cast("JavaModel", self._instance)
+ if isinstance(instance, JavaModel):
+ model = cast("JavaModel", instance)
params = serialize_ml_params(model, session.client)
assert isinstance(model._java_obj, str)
writer = pb2.MlCommand.Write(
obj_ref=pb2.ObjectRef(id=model._java_obj),
params=params,
path=path,
- should_overwrite=self.shouldOverwrite,
- options=self.optionMap,
+ should_overwrite=shouldOverwrite,
+ options=optionMap,
)
- else:
+ command = pb2.Command()
+ command.ml_command.write.CopyFrom(writer)
+ session.client.execute_command(command)
+
+ elif isinstance(instance, (JavaEstimator, JavaTransformer,
JavaEvaluator)):
operator: Union[JavaEstimator, JavaTransformer, JavaEvaluator]
- if isinstance(self._instance, JavaEstimator):
+ if isinstance(instance, JavaEstimator):
ml_type = pb2.MlOperator.ESTIMATOR
- operator = cast("JavaEstimator", self._instance)
- elif isinstance(self._instance, JavaEvaluator):
+ operator = cast("JavaEstimator", instance)
+ elif isinstance(instance, JavaEvaluator):
ml_type = pb2.MlOperator.EVALUATOR
- operator = cast("JavaEvaluator", self._instance)
- elif isinstance(self._instance, JavaTransformer):
- ml_type = pb2.MlOperator.TRANSFORMER
- operator = cast("JavaTransformer", self._instance)
+ operator = cast("JavaEvaluator", instance)
else:
- raise NotImplementedError(f"Unsupported writing for
{self._instance}")
+ ml_type = pb2.MlOperator.TRANSFORMER
+ operator = cast("JavaTransformer", instance)
params = serialize_ml_params(operator, session.client)
assert isinstance(operator._java_obj, str)
writer = pb2.MlCommand.Write(
operator=pb2.MlOperator(name=operator._java_obj,
uid=operator.uid, type=ml_type),
params=params,
path=path,
- should_overwrite=self.shouldOverwrite,
- options=self.optionMap,
+ should_overwrite=shouldOverwrite,
+ options=optionMap,
)
- command = pb2.Command()
- command.ml_command.write.CopyFrom(writer)
- session.client.execute_command(command)
+ command = pb2.Command()
+ command.ml_command.write.CopyFrom(writer)
+ session.client.execute_command(command)
+
+ elif isinstance(instance, (Pipeline, PipelineModel)):
+ from pyspark.ml.pipeline import PipelineSharedReadWrite
+
+ if shouldOverwrite:
Review Comment:
It is somewhat tricky here: we don't have a method to remove the path in
Spark Connect, so overwrite cannot delete the existing directory first.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]