This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new b0e18ba71c98 [SPARK-51003][ML][PYTHON][CONNECT] Support LSH models on
Connect
b0e18ba71c98 is described below
commit b0e18ba71c98dfcc38c41e1a10b492f7d742bec7
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Mon Jan 27 18:23:43 2025 +0800
[SPARK-51003][ML][PYTHON][CONNECT] Support LSH models on Connect
### What changes were proposed in this pull request?
Support LSH models on Connect
### Why are the changes needed?
For feature parity with Spark Classic.
### Does this PR introduce _any_ user-facing change?
Yes.
### How was this patch tested?
Added tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #49689 from zhengruifeng/ml_connect_lsh.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../services/org.apache.spark.ml.Estimator | 2 +
.../services/org.apache.spark.ml.Transformer | 2 +
.../ml/feature/BucketedRandomProjectionLSH.scala | 2 +
.../scala/org/apache/spark/ml/feature/LSH.scala | 2 +-
.../org/apache/spark/ml/feature/MinHashLSH.scala | 2 +
python/pyspark/ml/feature.py | 2 +
python/pyspark/ml/tests/test_feature.py | 110 +++++++++++++++++++++
.../org/apache/spark/sql/connect/ml/MLUtils.scala | 1 +
8 files changed, 122 insertions(+), 1 deletion(-)
diff --git
a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
index 61338f561868..26ea0fe5a00e 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
@@ -66,3 +66,5 @@ org.apache.spark.ml.feature.Word2Vec
org.apache.spark.ml.feature.CountVectorizer
org.apache.spark.ml.feature.OneHotEncoder
org.apache.spark.ml.feature.TargetEncoder
+org.apache.spark.ml.feature.BucketedRandomProjectionLSH
+org.apache.spark.ml.feature.MinHashLSH
diff --git
a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
index 04cde68ec806..06375a701010 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
@@ -83,3 +83,5 @@ org.apache.spark.ml.feature.Word2VecModel
org.apache.spark.ml.feature.CountVectorizerModel
org.apache.spark.ml.feature.OneHotEncoderModel
org.apache.spark.ml.feature.TargetEncoderModel
+org.apache.spark.ml.feature.BucketedRandomProjectionLSHModel
+org.apache.spark.ml.feature.MinHashLSHModel
diff --git
a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
index 537cb5020c88..5037ac941afb 100644
---
a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
+++
b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
@@ -68,6 +68,8 @@ class BucketedRandomProjectionLSHModel private[ml](
private[ml] val randMatrix: Matrix)
extends LSHModel[BucketedRandomProjectionLSHModel] with
BucketedRandomProjectionLSHParams {
+ private[ml] def this() = this(Identifiable.randomUID("brp-lsh"),
Matrices.empty)
+
private[ml] def this(uid: String, randUnitVectors: Array[Vector]) = {
this(uid, Matrices.fromVectors(randUnitVectors.toImmutableArraySeq))
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
index 2515365a6a3c..9c3b39b12bdc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
@@ -60,7 +60,7 @@ private[ml] trait LSHParams extends HasInputCol with
HasOutputCol {
/**
* Model produced by [[LSH]].
*/
-private[ml] abstract class LSHModel[T <: LSHModel[T]]
+private[spark] abstract class LSHModel[T <: LSHModel[T]]
extends Model[T] with LSHParams with MLWritable {
self: T =>
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
index 3f2a3327128a..d077b0a4a022 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
@@ -48,6 +48,8 @@ class MinHashLSHModel private[ml](
private[ml] val randCoefficients: Array[(Int, Int)])
extends LSHModel[MinHashLSHModel] {
+ private[ml] def this() = this(Identifiable.randomUID("mh-lsh"), Array.empty)
+
/** @group setParam */
@Since("2.4.0")
override def setInputCol(value: String): this.type = super.set(inputCol,
value)
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 4cc45c1bf194..81f6c7ebcbdf 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -387,6 +387,7 @@ class _LSHModel(JavaModel, _LSHParams):
"""
return self._set(outputCol=value)
+ @try_remote_attribute_relation
def approxNearestNeighbors(
self,
dataset: DataFrame,
@@ -424,6 +425,7 @@ class _LSHModel(JavaModel, _LSHParams):
"""
return self._call_java("approxNearestNeighbors", dataset, key,
numNearestNeighbors, distCol)
+ @try_remote_attribute_relation
def approxSimilarityJoin(
self,
datasetA: DataFrame,
diff --git a/python/pyspark/ml/tests/test_feature.py
b/python/pyspark/ml/tests/test_feature.py
index d7bd5ef4a1fc..ee6f8a78cc4c 100644
--- a/python/pyspark/ml/tests/test_feature.py
+++ b/python/pyspark/ml/tests/test_feature.py
@@ -68,6 +68,10 @@ from pyspark.ml.feature import (
PCAModel,
Word2Vec,
Word2VecModel,
+ BucketedRandomProjectionLSH,
+ BucketedRandomProjectionLSHModel,
+ MinHashLSH,
+ MinHashLSHModel,
)
from pyspark.ml.linalg import DenseVector, SparseVector, Vectors
from pyspark.sql import Row
@@ -1342,6 +1346,112 @@ class FeatureTestsMixin:
tf2 = HashingTF.load(d)
self.assertEqual(str(tf), str(tf2))
+ def test_bucketed_random_projection_lsh(self):
+ spark = self.spark
+
+ data = [
+ (0, Vectors.dense([-1.0, -1.0])),
+ (1, Vectors.dense([-1.0, 1.0])),
+ (2, Vectors.dense([1.0, -1.0])),
+ (3, Vectors.dense([1.0, 1.0])),
+ ]
+ df = spark.createDataFrame(data, ["id", "features"])
+
+ data2 = [
+ (4, Vectors.dense([2.0, 2.0])),
+ (5, Vectors.dense([2.0, 3.0])),
+ (6, Vectors.dense([3.0, 2.0])),
+ (7, Vectors.dense([3.0, 3.0])),
+ ]
+ df2 = spark.createDataFrame(data2, ["id", "features"])
+
+ brp = BucketedRandomProjectionLSH()
+ brp.setInputCol("features")
+ brp.setOutputCol("hashes")
+ brp.setSeed(12345)
+ brp.setBucketLength(1.0)
+
+ self.assertEqual(brp.getInputCol(), "features")
+ self.assertEqual(brp.getOutputCol(), "hashes")
+ self.assertEqual(brp.getBucketLength(), 1.0)
+ self.assertEqual(brp.getSeed(), 12345)
+
+ model = brp.fit(df)
+
+ output = model.transform(df)
+ self.assertEqual(output.columns, ["id", "features", "hashes"])
+ self.assertEqual(output.count(), 4)
+
+ output = model.approxNearestNeighbors(df2, Vectors.dense([1.0, 2.0]),
1)
+ self.assertEqual(output.columns, ["id", "features", "hashes",
"distCol"])
+ self.assertEqual(output.count(), 1)
+
+ output = model.approxSimilarityJoin(df, df2, 3)
+ self.assertEqual(output.columns, ["datasetA", "datasetB", "distCol"])
+ self.assertEqual(output.count(), 1)
+
+ # save & load
+ with
tempfile.TemporaryDirectory(prefix="bucketed_random_projection_lsh") as d:
+ brp.write().overwrite().save(d)
+ brp2 = BucketedRandomProjectionLSH.load(d)
+ self.assertEqual(str(brp), str(brp2))
+
+ model.write().overwrite().save(d)
+ model2 = BucketedRandomProjectionLSHModel.load(d)
+ self.assertEqual(str(model), str(model2))
+
+ def test_min_hash_lsh(self):
+ spark = self.spark
+
+ data = [
+ (0, Vectors.dense([-1.0, -1.0])),
+ (1, Vectors.dense([-1.0, 1.0])),
+ (2, Vectors.dense([1.0, -1.0])),
+ (3, Vectors.dense([1.0, 1.0])),
+ ]
+ df = spark.createDataFrame(data, ["id", "features"])
+
+ data2 = [
+ (4, Vectors.dense([2.0, 2.0])),
+ (5, Vectors.dense([2.0, 3.0])),
+ (6, Vectors.dense([3.0, 2.0])),
+ (7, Vectors.dense([3.0, 3.0])),
+ ]
+ df2 = spark.createDataFrame(data2, ["id", "features"])
+
+ mh = MinHashLSH()
+ mh.setInputCol("features")
+ mh.setOutputCol("hashes")
+ mh.setSeed(12345)
+
+ self.assertEqual(mh.getInputCol(), "features")
+ self.assertEqual(mh.getOutputCol(), "hashes")
+ self.assertEqual(mh.getSeed(), 12345)
+
+ model = mh.fit(df)
+
+ output = model.transform(df)
+ self.assertEqual(output.columns, ["id", "features", "hashes"])
+ self.assertEqual(output.count(), 4)
+
+ output = model.approxNearestNeighbors(df2, Vectors.dense([1.0, 2.0]),
1)
+ self.assertEqual(output.columns, ["id", "features", "hashes",
"distCol"])
+ self.assertEqual(output.count(), 1)
+
+ output = model.approxSimilarityJoin(df, df2, 3)
+ self.assertEqual(output.columns, ["datasetA", "datasetB", "distCol"])
+ self.assertEqual(output.count(), 16)
+
+ # save & load
+ with tempfile.TemporaryDirectory(prefix="min_hash_lsh") as d:
+ mh.write().overwrite().save(d)
+ mh2 = MinHashLSH.load(d)
+ self.assertEqual(str(mh), str(mh2))
+
+ model.write().overwrite().save(d)
+ model2 = MinHashLSHModel.load(d)
+ self.assertEqual(str(model), str(model2))
+
class FeatureTests(FeatureTestsMixin, SparkSessionTestCase):
pass
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
index 56526b7e6737..be067ed8972e 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
@@ -458,6 +458,7 @@ private[ml] object MLUtils {
(classOf[PredictionModel[_, _]], Set("predict", "numFeatures")),
(classOf[ClassificationModel[_, _]], Set("predictRaw", "numClasses")),
(classOf[ProbabilisticClassificationModel[_, _]],
Set("predictProbability")),
+ (classOf[LSHModel[_]], Set("approxNearestNeighbors",
"approxSimilarityJoin")),
// Summary Traits
(classOf[HasTrainingSummary[_]], Set("hasSummary", "summary")),
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]