Github user mengxr commented on a diff in the pull request:
https://github.com/apache/spark/pull/6499#discussion_r32784743
--- Diff: python/pyspark/mllib/clustering.py ---
@@ -264,6 +270,190 @@ def train(cls, rdd, k, convergenceTol=1e-3,
maxIterations=100, seed=None, initia
return GaussianMixtureModel(weight, mvg_obj)
+class StreamingKMeansModel(KMeansModel):
+ """
+ .. note:: Experimental
+ Clustering model which can perform an online update of the centroids.
+
+ The update formula for each centroid is given by
+ c_t+1 = [(c_t * n_t * a) + (x_t * m_t)] / [n_t + m_t]
+ n_t+1 = n_t * a + m_t
+
+ where
+ c_t: Centroid at the t-th iteration.
+ n_t: Number of samples (or) weights associated with the centroid
+ at the t-th iteration.
+ x_t: Centroid of the new data closest to c_t.
+ m_t: Number of samples (or) weights of the new data closest to c_t
+ c_t+1: New centroid.
+ n_t+1: New number of weights.
+ a: Decay Factor, which gives the forgetfulness.
+
+ Note that if a is set to 1, it is the weighted mean of the previous
+ and new data. If it is set to zero, the old centroids are completely
+ forgotten.
+
+ >>> initCenters, initWeights = [[0.0, 0.0], [1.0, 1.0]], [1.0, 1.0]
+ >>> stkm = StreamingKMeansModel(initCenters, initWeights)
+ >>> data = sc.parallelize([[-0.1, -0.1], [0.1, 0.1],
+ ... [0.9, 0.9], [1.1, 1.1]])
+ >>> stkm = stkm.update(data, 1.0, u"batches")
+ >>> stkm.centers
+ array([[ 0., 0.],
+ [ 1., 1.]])
+ >>> stkm.predict([-0.1, -0.1]) == stkm.predict([0.1, 0.1]) == 0
+ True
+ >>> stkm.predict([0.9, 0.9]) == stkm.predict([1.1, 1.1]) == 1
+ True
+ >>> stkm.clusterWeights
+ [3.0, 3.0]
+ >>> decayFactor = 0.0
+ >>> data = sc.parallelize([DenseVector([1.5, 1.5]), DenseVector([0.2,
0.2])])
+ >>> stkm = stkm.update(data, 0.0, u"batches")
+ >>> stkm.centers
+ array([[ 0.2, 0.2],
+ [ 1.5, 1.5]])
+ >>> stkm.clusterWeights
+ [1.0, 1.0]
+ >>> stkm.predict([0.2, 0.2])
+ 0
+ >>> stkm.predict([1.5, 1.5])
+ 1
+
+ :param clusterCenters: Initial cluster centers.
+ :param clusterWeights: List of weights assigned to each cluster.
+ """
+ def __init__(self, clusterCenters, clusterWeights):
+ super(StreamingKMeansModel, self).__init__(centers=clusterCenters)
+ self._clusterWeights = list(clusterWeights)
+
+ @property
+ def clusterWeights(self):
+ """Convenience method to return the cluster weights."""
+ return self._clusterWeights
+
+ @ignore_unicode_prefix
+ def update(self, data, decayFactor, timeUnit):
+ """Update the centroids, according to data
+
+ :param data: Should be a RDD that represents the new data.
+ :param decayFactor: forgetfulness of the previous centroids.
+ :param timeUnit: Can be "batches" or "points"
+
+ If points, then the decay factor is raised to the power of
+ number of new points and if batches, it is used as it is.
+ """
+ if not isinstance(data, RDD):
+ raise TypeError("data should be of a RDD, got %s." %
type(data))
--- End diff --
`data` -> `Data`
`a RDD` -> `an RDD`
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]