wbo4958 commented on code in PR #48791:
URL: https://github.com/apache/spark/pull/48791#discussion_r1877882367
##########
sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala:
##########
@@ -111,6 +112,9 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSession)
private[spark] lazy val dataFrameCache: ConcurrentMap[String, DataFrame] =
new ConcurrentHashMap()
+ // ML model cache
+ private[connect] lazy val mlCache = new MLCache()
Review Comment:
Currently I don't have a good estimate of the object sizes; some objects
will be small while others are large, so it depends.
Yes, the destructor of the Python class will trigger a gRPC call to delete
the cached object on the server side.
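For reference, a minimal sketch of the server-side eviction that such a gRPC
delete request could map to (the `remove` method here is illustrative, not
necessarily the PR's actual API):

```scala
import java.util.concurrent.ConcurrentHashMap

// Hypothetical sketch: when the client's destructor fires a delete request,
// the server evicts the cached object instead of holding it for the whole
// session lifetime.
class MLCacheEvictionSketch {
  private val cached = new ConcurrentHashMap[String, Object]()

  // ConcurrentHashMap.remove returns null if the key is absent,
  // so wrap it in Option for the caller.
  def remove(refId: String): Option[Object] = Option(cached.remove(refId))
}
```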
##########
mllib/src/main/scala/org/apache/spark/ml/util/Summary.scala:
##########
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import org.apache.spark.annotation.Since
+
+/**
+ * A marker trait for model summaries.
+ */
+@Since("4.0.0")
+private[spark] trait Summary
Review Comment:
Make all the summaries extend this Summary trait, so that in Connect we can
match against the trait and decide whether a summary object needs to be cached.
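For illustration, the matching could look like this sketch (the helper name
`shouldCache` is hypothetical, not code from the PR):

```scala
import org.apache.spark.ml.util.Summary

object CacheCheckSketch {
  // Hypothetical server-side check: objects extending the new marker
  // trait get kept in MLCache instead of being sent back to the client.
  def shouldCache(obj: Object): Boolean = obj match {
    case _: Summary => true
    case _ => false
  }
}
```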
##########
sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala:
##########
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.ml
+
+import java.util.UUID
+import java.util.concurrent.ConcurrentHashMap
+
+import org.apache.spark.internal.Logging
+
+/**
+ * MLCache is for caching ML objects, typically models and summaries
+ * evaluated by a model.
+ */
+private[connect] class MLCache extends Logging {
+ private val cachedModel: ConcurrentHashMap[String, Object] =
+ new ConcurrentHashMap[String, Object]()
+
+ /**
+ * Cache an object into a map of MLCache, and return its key
+ * @param obj
+ * the object to be cached
+ * @return
+ * the key
+ */
+ def register(obj: Object): String = {
+ val objectId = UUID.randomUUID().toString.takeRight(12)
Review Comment:
Hmm, we could use the full UUID string instead of just taking the last 12
characters.
> An alternative would be to generate a predictable Id.
Could you show me how to do that?
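For comparison, a sketch of both options (the counter-based generator is an
assumption for illustration, not an existing API):

```scala
import java.util.UUID
import java.util.concurrent.atomic.AtomicLong

object IdSketch {
  // Option 1: the full 36-char UUID string, which avoids the higher
  // collision risk of keeping only the last 12 characters.
  def randomId(): String = UUID.randomUUID().toString

  // Option 2: a predictable id from a monotonic counter; keys are then
  // deterministic per cache instance, which can simplify debugging.
  private val counter = new AtomicLong(0L)
  def nextId(): String = s"obj-${counter.getAndIncrement()}"
}
```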
##########
sql/connect/common/src/main/protobuf/spark/connect/ml_common.proto:
##########
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+syntax = 'proto3';
+
+package spark.connect;
+
+import "spark/connect/expressions.proto";
+
+option java_multiple_files = true;
+option java_package = "org.apache.spark.connect.proto";
+
+// MlParams stores param settings for ML Estimator / Transformer / Evaluator
+message MlParams {
+ // User-supplied params
+ map<string, Param> params = 1;
+}
+
+// Represents the parameter type of the ML instance, or the returned value
+// of the attribute
+message Param {
+ oneof param_type {
+ Expression.Literal literal = 1;
+ Vector vector = 2;
+ Matrix matrix = 3;
+ }
+}
+
+// MlOperator represents ML operators such as Estimator, Transformer or Evaluator
+message MlOperator {
+ // The qualified name of the ML operator.
+ string name = 1;
+ // Unique id of the ML operator
+ string uid = 2;
Review Comment:
No, the uid is the `Identifiable` uid that each ML operator carries.
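For context, Spark ML operators mix in `Identifiable`, whose `uid` is
typically created via the real `Identifiable.randomUID` helper (the prefix
varies per operator; "logreg" below is just an example):

```scala
import org.apache.spark.ml.util.Identifiable

object UidExample {
  // Each ML operator carries a uid such as "logreg_a1b2c3d4e5f6";
  // randomUID appends a short random suffix to the given prefix.
  val uid: String = Identifiable.randomUID("logreg")
}
```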
##########
sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala:
##########
@@ -0,0 +1,293 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect.ml
+
+import java.util.ServiceLoader
+
+import scala.collection.immutable.HashSet
+import scala.jdk.CollectionConverters.{IterableHasAsScala, MapHasAsScala}
+
+import org.apache.commons.lang3.reflect.MethodUtils.{invokeMethod, invokeStaticMethod}
+
+import org.apache.spark.connect.proto
+import org.apache.spark.ml.{Estimator, Transformer}
+import org.apache.spark.ml.linalg.{Matrices, Matrix, Vector, Vectors}
+import org.apache.spark.ml.param.Params
+import org.apache.spark.ml.util.MLWritable
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.connect.common.LiteralValueProtoConverter
+import org.apache.spark.sql.connect.planner.SparkConnectPlanner
+import org.apache.spark.sql.connect.service.SessionHolder
+import org.apache.spark.util.Utils
+
+private[ml] object MLUtils {
+
+ private lazy val estimators: Map[String, Class[_]] = {
Review Comment:
Sounds good. Done
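For illustration, a `ServiceLoader`-based registry along these lines could
populate that map (a hedged sketch; the PR's actual discovery mechanism and
naming may differ):

```scala
import java.util.ServiceLoader

import scala.jdk.CollectionConverters.IterableHasAsScala

import org.apache.spark.ml.Estimator

object EstimatorRegistrySketch {
  // Hypothetical sketch: discover Estimator implementations declared in
  // META-INF/services and index them by fully qualified class name.
  def loadEstimators(loader: ClassLoader): Map[String, Class[_]] =
    ServiceLoader
      .load(classOf[Estimator[_]], loader)
      .asScala
      .map(e => e.getClass.getName -> e.getClass)
      .toMap
}
```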
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]