zhipeng93 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r834912447
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
##########
@@ -159,68 +161,78 @@ public IterationBodyResult process(
DenseVectorTypeInfo.INSTANCE),
new SelectNearestCentroidOperator(distanceMeasure));
- AllWindowFunction<DenseVector, DenseVector[], TimeWindow> toList =
- new AllWindowFunction<DenseVector, DenseVector[], TimeWindow>() {
- @Override
- public void apply(
- TimeWindow timeWindow,
- Iterable<DenseVector> iterable,
- Collector<DenseVector[]> out) {
- List<DenseVector> centroids = IteratorUtils.toList(iterable.iterator());
- out.collect(centroids.toArray(new DenseVector[0]));
- }
- };
-
PerRoundSubBody perRoundSubBody =
new PerRoundSubBody() {
@Override
public DataStreamList process(DataStreamList inputs) {
DataStream<Tuple2<Integer, DenseVector>> centroidIdAndPoints =
inputs.get(0);
- DataStream<DenseVector[]> newCentroids =
+ DataStream<KMeansModelData> modelDataStream =
centroidIdAndPoints
.map(new CountAppender())
.keyBy(t -> t.f0)
.window(EndOfStreamWindows.get())
.reduce(new CentroidAccumulator())
.map(new CentroidAverager())
.windowAll(EndOfStreamWindows.get())
- .apply(toList);
- return DataStreamList.of(newCentroids);
+ .apply(new ModelDataGenerator());
+ return DataStreamList.of(modelDataStream);
}
};
-
- DataStream<DenseVector[]> newCentroids =
+ DataStream<KMeansModelData> newModelData =
IterationBody.forEachRound(
DataStreamList.of(centroidIdAndPoints),
perRoundSubBody)
.get(0);
- DataStream<DenseVector[]> finalCentroids =
- newCentroids.flatMap(new ForwardInputsOfLastRound<>());
+
+ DataStream<DenseVector[]> newCentroids =
+ newModelData.map(x -> x.centroids).setParallelism(1);
+
+ DataStream<KMeansModelData> finalModelData =
+ newModelData.flatMap(new ForwardInputsOfLastRound<>());
return new IterationBodyResult(
DataStreamList.of(newCentroids),
- DataStreamList.of(finalCentroids),
+ DataStreamList.of(finalModelData),
terminationCriteria);
}
}
+ private static class ModelDataGenerator
+ implements AllWindowFunction<Tuple2<DenseVector, Double>, KMeansModelData, TimeWindow> {
+ @Override
+ public void apply(
+ TimeWindow timeWindow,
+ Iterable<Tuple2<DenseVector, Double>> iterable,
+ Collector<KMeansModelData> collector) {
+ List<Tuple2<DenseVector, Double>> centroidsAndWeights =
Review comment:
Could we pass `k` (the number of clusters) to `ModelDataGenerator` as a
constructor parameter, so that we can avoid building an intermediate list of
centroids? That would be more memory-efficient when `k` is large.
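
A rough sketch of the idea, not tied to the actual Flink ML classes: `WeightedCentroid` and `ModelData` below are hypothetical stand-ins for `Tuple2<DenseVector, Double>` and `KMeansModelData`, and `generate` plays the role of `ModelDataGenerator#apply`. It assumes the window yields at most `k` elements and trims the arrays if fewer arrive.

```java
import java.util.Arrays;

final class PreallocatedModelData {

    /** Hypothetical stand-in for Tuple2<DenseVector, Double>. */
    static final class WeightedCentroid {
        final double[] centroid;
        final double weight;

        WeightedCentroid(double[] centroid, double weight) {
            this.centroid = centroid;
            this.weight = weight;
        }
    }

    /** Hypothetical stand-in for KMeansModelData. */
    static final class ModelData {
        final double[][] centroids;
        final double[] weights;

        ModelData(double[][] centroids, double[] weights) {
            this.centroids = centroids;
            this.weights = weights;
        }
    }

    /**
     * Fills arrays sized by k directly from the iterable, so no
     * intermediate List is materialized.
     */
    static ModelData generate(int k, Iterable<WeightedCentroid> input) {
        double[][] centroids = new double[k][];
        double[] weights = new double[k];
        int n = 0;
        for (WeightedCentroid wc : input) {
            if (n == k) {
                // Assumption: the upstream reduce emits at most one element per cluster.
                throw new IllegalStateException("Received more than k=" + k + " centroids");
            }
            centroids[n] = wc.centroid;
            weights[n] = wc.weight;
            n++;
        }
        if (n < k) {
            // Fewer than k elements arrived (e.g. an empty cluster): trim the arrays.
            centroids = Arrays.copyOf(centroids, n);
            weights = Arrays.copyOf(weights, n);
        }
        return new ModelData(centroids, weights);
    }
}
```

The trade-off is that `k` has to be known when the function is constructed, and the function needs a policy for receiving more or fewer than `k` elements; the sketch throws in the first case and trims in the second.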