This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new a3cf9c5 [SPARK-30247][PYSPARK][FOLLOWUP] Add Python class
MultivariateGaussian
a3cf9c5 is described below
commit a3cf9c564e74effe0f8457eaf9835ca0d3ab8be3
Author: Huaxin Gao <[email protected]>
AuthorDate: Fri Dec 27 13:30:18 2019 +0800
[SPARK-30247][PYSPARK][FOLLOWUP] Add Python class MultivariateGaussian
### What changes were proposed in this pull request?
Add a corresponding Python class, MultivariateGaussian, that holds a mean
vector and a covariance matrix, so that the Gaussians of a
GaussianMixtureModel can be used from the Python side.
### Does this PR introduce any user-facing change?
Yes. It adds the Python class `MultivariateGaussian`.
### How was this patch tested?
Added doctests.
Closes #27020 from huaxingao/spark-30247.
Authored-by: Huaxin Gao <[email protected]>
Signed-off-by: zhengruifeng <[email protected]>
---
python/pyspark/ml/clustering.py | 36 +++++++++++++++++++++++++++++++++---
python/pyspark/ml/stat.py | 17 +++++++++++++++++
2 files changed, 50 insertions(+), 3 deletions(-)
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index f784b8f..7295b76 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -22,7 +22,8 @@ from pyspark import since, keyword_only
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams,
JavaWrapper
from pyspark.ml.param.shared import *
-from pyspark.ml.common import inherit_doc
+from pyspark.ml.common import inherit_doc, _java2py
+from pyspark.ml.stat import MultivariateGaussian
from pyspark.sql import DataFrame
__all__ = ['BisectingKMeans', 'BisectingKMeansModel', 'BisectingKMeansSummary',
@@ -161,7 +162,11 @@ class GaussianMixtureModel(JavaModel,
_GaussianMixtureParams, JavaMLWritable, Ja
Array of :py:class:`MultivariateGaussian` where gaussians[i] represents
the Multivariate Gaussian (Normal) Distribution for Gaussian i
"""
- return self._call_java("gaussians")
+ sc = SparkContext._active_spark_context
+ jgaussians = self._java_obj.gaussians()
+ return [
+ MultivariateGaussian(_java2py(sc, jgaussian.mean()), _java2py(sc,
jgaussian.cov()))
+ for jgaussian in jgaussians]
@property
@since("2.0.0")
@@ -263,6 +268,21 @@ class GaussianMixture(JavaEstimator,
_GaussianMixtureParams, JavaMLWritable, Jav
>>> gaussians = model.gaussians
>>> len(gaussians)
3
+ >>> gaussians[0].mean
+ DenseVector([0.825, 0.8675])
+ >>> gaussians[0].cov.toArray()
+ array([[ 0.005625 , -0.0050625 ],
+ [-0.0050625 , 0.00455625]])
+ >>> gaussians[1].mean
+ DenseVector([-0.4777, -0.4096])
+ >>> gaussians[1].cov.toArray()
+ array([[ 0.1679695 , 0.13181786],
+ [ 0.13181786, 0.10524592]])
+ >>> gaussians[2].mean
+ DenseVector([-0.4473, -0.3853])
+ >>> gaussians[2].cov.toArray()
+ array([[ 0.16730412, 0.13112435],
+ [ 0.13112435, 0.10469614]])
>>> model.gaussiansDF.select("mean").head()
Row(mean=DenseVector([0.825, 0.8675]))
>>> model.gaussiansDF.select("cov").head()
@@ -285,7 +305,17 @@ class GaussianMixture(JavaEstimator,
_GaussianMixtureParams, JavaMLWritable, Jav
False
>>> model2.weights == model.weights
True
- >>> model2.gaussians == model.gaussians
+ >>> model2.gaussians[0].mean == model.gaussians[0].mean
+ True
+ >>> model2.gaussians[0].cov == model.gaussians[0].cov
+ True
+ >>> model2.gaussians[1].mean == model.gaussians[1].mean
+ True
+ >>> model2.gaussians[1].cov == model.gaussians[1].cov
+ True
+ >>> model2.gaussians[2].mean == model.gaussians[2].mean
+ True
+ >>> model2.gaussians[2].cov == model.gaussians[2].cov
True
>>> model2.gaussiansDF.select("mean").head()
Row(mean=DenseVector([0.825, 0.8675]))
diff --git a/python/pyspark/ml/stat.py b/python/pyspark/ml/stat.py
index 8f2eadd..53a57af 100644
--- a/python/pyspark/ml/stat.py
+++ b/python/pyspark/ml/stat.py
@@ -19,6 +19,7 @@ import sys
from pyspark import since, SparkContext
from pyspark.ml.common import _java2py, _py2java
+from pyspark.ml.linalg import DenseMatrix, Vectors
from pyspark.ml.wrapper import JavaWrapper, _jvm
from pyspark.sql.column import Column, _to_seq
from pyspark.sql.functions import lit
@@ -394,6 +395,22 @@ class SummaryBuilder(JavaWrapper):
return Column(self._java_obj.summary(featuresCol._jc, weightCol._jc))
+class MultivariateGaussian(object):
+ """Represents a (mean, cov) tuple
+
+ >>> m = MultivariateGaussian(Vectors.dense([11,12]), DenseMatrix(2, 2, (1.0, 3.0, 5.0, 2.0)))
+ >>> (m.mean, m.cov.toArray())
+ (DenseVector([11.0, 12.0]), array([[ 1., 5.],
+ [ 3., 2.]]))
+
+ .. versionadded:: 3.0.0
+
+ """
+ def __init__(self, mean, cov):
+ self.mean = mean
+ self.cov = cov
+
+
if __name__ == "__main__":
import doctest
import numpy
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]