zhipeng93 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r834912599



##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
##########
@@ -159,68 +161,78 @@ public IterationBodyResult process(
                                             DenseVectorTypeInfo.INSTANCE),
                                     new 
SelectNearestCentroidOperator(distanceMeasure));
 
-            AllWindowFunction<DenseVector, DenseVector[], TimeWindow> toList =
-                    new AllWindowFunction<DenseVector, DenseVector[], 
TimeWindow>() {
-                        @Override
-                        public void apply(
-                                TimeWindow timeWindow,
-                                Iterable<DenseVector> iterable,
-                                Collector<DenseVector[]> out) {
-                            List<DenseVector> centroids = 
IteratorUtils.toList(iterable.iterator());
-                            out.collect(centroids.toArray(new DenseVector[0]));
-                        }
-                    };
-
             PerRoundSubBody perRoundSubBody =
                     new PerRoundSubBody() {
                         @Override
                         public DataStreamList process(DataStreamList inputs) {
                             DataStream<Tuple2<Integer, DenseVector>> 
centroidIdAndPoints =
                                     inputs.get(0);
-                            DataStream<DenseVector[]> newCentroids =
+                            DataStream<KMeansModelData> modelDataStream =
                                     centroidIdAndPoints
                                             .map(new CountAppender())
                                             .keyBy(t -> t.f0)
                                             .window(EndOfStreamWindows.get())
                                             .reduce(new CentroidAccumulator())
                                             .map(new CentroidAverager())
                                             
.windowAll(EndOfStreamWindows.get())
-                                            .apply(toList);
-                            return DataStreamList.of(newCentroids);
+                                            .apply(new ModelDataGenerator());
+                            return DataStreamList.of(modelDataStream);
                         }
                     };
-
-            DataStream<DenseVector[]> newCentroids =
+            DataStream<KMeansModelData> newModelData =
                     IterationBody.forEachRound(
                                     DataStreamList.of(centroidIdAndPoints), 
perRoundSubBody)
                             .get(0);
-            DataStream<DenseVector[]> finalCentroids =
-                    newCentroids.flatMap(new ForwardInputsOfLastRound<>());
+
+            DataStream<DenseVector[]> newCentroids =
+                    newModelData.map(x -> x.centroids).setParallelism(1);
+
+            DataStream<KMeansModelData> finalModelData =
+                    newModelData.flatMap(new ForwardInputsOfLastRound<>());
 
             return new IterationBodyResult(
                     DataStreamList.of(newCentroids),
-                    DataStreamList.of(finalCentroids),
+                    DataStreamList.of(finalModelData),
                     terminationCriteria);
         }
     }
 
+    private static class ModelDataGenerator
+            implements AllWindowFunction<Tuple2<DenseVector, Double>, 
KMeansModelData, TimeWindow> {
+        @Override
+        public void apply(
+                TimeWindow timeWindow,
+                Iterable<Tuple2<DenseVector, Double>> iterable,
+                Collector<KMeansModelData> collector) {
+            List<Tuple2<DenseVector, Double>> centroidsAndWeights =

Review comment:
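       Could we pass `k` (the number of clusters) as a parameter to
       `ModelDataGenerator`, so that the centroids and their weights can be
       written directly into arrays of size `k` instead of first being collected
       into an intermediate list? This could be more memory-efficient when `k` is
       large.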

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java
##########
@@ -47,8 +47,16 @@
 
     public DenseVector[] centroids;
 
+    public DenseVector weights;

Review comment:
       nit: Could we add Javadoc explaining why `weights` is added here (i.e.,
       what each entry represents and how it is used)?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,437 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import 
org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the function of {@link KMeans}, supporting to train a 
K-Means model
+ * continuously according to an unbounded stream of train data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, 
generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current 
batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between 
the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of 
points assigned to
+ * them. The weight of the original centroids is also the number of points, 
but additionally
+ * multiplying with the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated 
thus far. If decay
+ * factor is 1, all batches are weighted equally. If decay factor is 0, new 
centroids are determined
+ * entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, 
OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) 
tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new 
FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> finalModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), 
DataStreamList.of(points), body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelData);
+        OnlineKMeansModel model = new 
OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND bounded model data table (if exists) to the 
given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String 
path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new 
KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = 
tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int 
batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<KMeansModelData> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(KMeansModelData.class),
+                                    new ModelDataLocalUpdater(distanceMeasure, 
decayFactor))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData), 
DataStreamList.of(modelData));
+        }
+    }
+
+    private static class ModelDataGlobalReducer implements 
ReduceFunction<KMeansModelData> {
+        @Override
+        public KMeansModelData reduce(KMeansModelData modelData, 
KMeansModelData newModelData) {
+            DenseVector weights = modelData.weights;
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector newWeights = newModelData.weights;
+            DenseVector[] newCentroids = newModelData.centroids;
+
+            int k = newCentroids.length;
+            int dim = newCentroids[0].size();
+
+            for (int i = 0; i < k; i++) {
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * 
newWeights.values[i])
+                                    / Math.max(weights.values[i] + 
newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            return new KMeansModelData(centroids, weights);
+        }
+    }
+
+    private static class ModelDataLocalUpdater extends 
AbstractStreamOperator<KMeansModelData>
+            implements TwoInputStreamOperator<DenseVector[], KMeansModelData, 
KMeansModelData> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private ListState<DenseVector[]> localBatchState;

Review comment:
       nit: Would `localBatchDataState` be a better name?

##########
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java
##########
@@ -0,0 +1,476 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.RestOptions;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.clustering.kmeans.KMeansModel;
+import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeans;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.util.InMemorySinkFunction;
+import org.apache.flink.ml.util.InMemorySourceFunction;
+import org.apache.flink.runtime.minicluster.MiniCluster;
+import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
+import org.apache.flink.runtime.testutils.InMemoryReporter;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.TestLogger;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static 
org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction;
+
+/** Tests {@link OnlineKMeans} and {@link OnlineKMeansModel}. */
+public class OnlineKMeansTest extends TestLogger {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final DenseVector[] trainData1 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 0.0),
+                Vectors.dense(10.0, 0.3),
+                Vectors.dense(10.3, 0.0),
+                Vectors.dense(-10.0, 0.0),
+                Vectors.dense(-10.0, 0.6),
+                Vectors.dense(-10.6, 0.0)
+            };
+    private static final DenseVector[] trainData2 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 100.0),
+                Vectors.dense(10.0, 100.3),
+                Vectors.dense(10.3, 100.0),
+                Vectors.dense(-10.0, -100.0),
+                Vectors.dense(-10.0, -100.6),
+                Vectors.dense(-10.6, -100.0)
+            };
+    private static final DenseVector[] predictData =
+            new DenseVector[] {
+                Vectors.dense(10.0, 10.0),
+                Vectors.dense(10.3, 10.0),
+                Vectors.dense(10.0, 10.3),
+                Vectors.dense(-10.0, 10.0),
+                Vectors.dense(-10.3, 10.0),
+                Vectors.dense(-10.0, 10.3)
+            };
+    private static final List<Set<DenseVector>> expectedGroups1 =
+            Arrays.asList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3))),
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+    private static final List<Set<DenseVector>> expectedGroups2 =
+            Collections.singletonList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3),
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+
+    private static final int defaultParallelism = 4;
+    private static final int numTaskManagers = 2;
+    private static final int numSlotsPerTaskManager = 2;
+
+    private int currentModelDataVersion;
+
+    private InMemorySourceFunction<DenseVector> trainSource;
+    private InMemorySourceFunction<DenseVector> predictSource;
+    private InMemorySinkFunction<Row> outputSink;
+    private InMemorySinkFunction<KMeansModelData> modelDataSink;
+
+    // TODO: creates static mini cluster once for whole test class after 
dependency upgrades to
+    // Flink 1.15.
+    private InMemoryReporter reporter;
+    private MiniCluster miniCluster;
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+
+    private Table offlineTrainTable;
+    private Table onlineTrainTable;
+    private Table onlinePredictTable;
+
+    @Before
+    public void before() throws Exception {
+        currentModelDataVersion = 0;
+
+        trainSource = new InMemorySourceFunction<>();
+        predictSource = new InMemorySourceFunction<>();
+        outputSink = new InMemorySinkFunction<>();
+        modelDataSink = new InMemorySinkFunction<>();
+
+        Configuration config = new Configuration();
+        config.set(RestOptions.BIND_PORT, "18081-19091");
+        
config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, 
true);
+        reporter = InMemoryReporter.createWithRetainedMetrics();
+        reporter.addToConfiguration(config);
+
+        miniCluster =
+                new MiniCluster(
+                        new MiniClusterConfiguration.Builder()
+                                .setConfiguration(config)
+                                .setNumTaskManagers(numTaskManagers)
+                                
.setNumSlotsPerTaskManager(numSlotsPerTaskManager)
+                                .build());
+        miniCluster.start();
+
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(defaultParallelism);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        Schema schema = Schema.newBuilder().column("f0", 
DataTypes.of(DenseVector.class)).build();
+
+        offlineTrainTable =
+                tEnv.fromDataStream(env.fromElements(trainData1), 
schema).as("features");
+        onlineTrainTable =
+                tEnv.fromDataStream(
+                                env.addSource(trainSource, 
DenseVectorTypeInfo.INSTANCE), schema)
+                        .as("features");
+        onlinePredictTable =
+                tEnv.fromDataStream(
+                                env.addSource(predictSource, 
DenseVectorTypeInfo.INSTANCE), schema)
+                        .as("features");
+    }
+
+    @After
+    public void after() throws Exception {
+        miniCluster.close();
+    }
+
+    /**
+     * Performs transform() on the provided model with predictTable, and adds 
sinks for
+     * OnlineKMeansModel's transform output and model data.
+     */
+    private void transformAndOutputData(OnlineKMeansModel onlineModel) {
+        Table outputTable = onlineModel.transform(onlinePredictTable)[0];
+        tEnv.toDataStream(outputTable).addSink(outputSink);
+
+        Table modelDataTable = onlineModel.getModelData()[0];
+        
KMeansModelData.getModelDataStream(modelDataTable).addSink(modelDataSink);
+    }
+
+    /** Blocks the thread until Model has set up init model data. */
+    private void waitInitModelDataSetup() throws InterruptedException {
+        while 
(reporter.findMetrics(OnlineKMeansModel.MODEL_DATA_VERSION_GAUGE_KEY).size()
+                < defaultParallelism) {
+            Thread.sleep(100);
+        }
+        waitModelDataUpdate();
+    }
+
+    /** Blocks the thread until the Model has received the next 
model-data-update event. */
+    @SuppressWarnings("unchecked")
+    private void waitModelDataUpdate() throws InterruptedException {
+        do {
+            int tmpModelDataVersion =
+                    
reporter.findMetrics(OnlineKMeansModel.MODEL_DATA_VERSION_GAUGE_KEY).values()
+                            .stream()
+                            .map(x -> Integer.parseInt(((Gauge<String>) 
x).getValue()))
+                            .min(Integer::compareTo)
+                            .get();
+            if (tmpModelDataVersion == currentModelDataVersion) {
+                Thread.sleep(100);
+            } else {
+                currentModelDataVersion = tmpModelDataVersion;
+                break;
+            }
+        } while (true);
+    }
+
+    /**
+     * Inserts default predict data to the predict queue, fetches the 
prediction results, and
+     * asserts that the grouping result is as expected.
+     *
+     * @param expectedGroups A list containing sets of features, which is the 
expected group result
+     * @param featuresCol Name of the column in the table that contains the 
features
+     * @param predictionCol Name of the column in the table that contains the 
prediction result
+     */
+    private void predictAndAssert(
+            List<Set<DenseVector>> expectedGroups, String featuresCol, String 
predictionCol)
+            throws Exception {
+        predictSource.addAll(OnlineKMeansTest.predictData);
+        List<Row> rawResult = 
outputSink.poll(OnlineKMeansTest.predictData.length);
+        List<Set<DenseVector>> actualGroups =
+                groupFeaturesByPrediction(rawResult, featuresCol, 
predictionCol);
+        Assert.assertTrue(CollectionUtils.isEqualCollection(expectedGroups, 
actualGroups));
+    }
+
+    @Test
+    public void testParam() {
+        OnlineKMeans onlineKMeans = new OnlineKMeans();
+        Assert.assertEquals("features", onlineKMeans.getFeaturesCol());
+        Assert.assertEquals("prediction", onlineKMeans.getPredictionCol());
+        Assert.assertEquals("count", onlineKMeans.getBatchStrategy());
+        Assert.assertEquals(EuclideanDistanceMeasure.NAME, 
onlineKMeans.getDistanceMeasure());
+        Assert.assertEquals(32, onlineKMeans.getGlobalBatchSize());
+        Assert.assertEquals(0., onlineKMeans.getDecayFactor(), 1e-5);
+        Assert.assertEquals(OnlineKMeans.class.getName().hashCode(), 
onlineKMeans.getSeed());
+
+        onlineKMeans
+                .setFeaturesCol("test_feature")
+                .setPredictionCol("test_prediction")
+                .setGlobalBatchSize(5)
+                .setDecayFactor(0.25)
+                .setSeed(100);
+
+        Assert.assertEquals("test_feature", onlineKMeans.getFeaturesCol());
+        Assert.assertEquals("test_prediction", 
onlineKMeans.getPredictionCol());
+        Assert.assertEquals("count", onlineKMeans.getBatchStrategy());
+        Assert.assertEquals(EuclideanDistanceMeasure.NAME, 
onlineKMeans.getDistanceMeasure());
+        Assert.assertEquals(5, onlineKMeans.getGlobalBatchSize());
+        Assert.assertEquals(0.25, onlineKMeans.getDecayFactor(), 1e-5);
+        Assert.assertEquals(100, onlineKMeans.getSeed());
+    }
+
+    @Test
+    public void testFitAndPredict() throws Exception {
+        OnlineKMeans onlineKMeans =
+                new OnlineKMeans()
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction")
+                        .setGlobalBatchSize(6)
+                        .setRandomCentroids(2, 0.);

Review comment:
       nit: Shall we change `0.` to `0.0` to keep the code style consistent?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasBatchStrategy.java
##########
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared batch strategy param. */
+public interface HasBatchStrategy<T> extends WithParams<T> {
+    String COUNT_STRATEGY = "count";

Review comment:
       How about explaining a bit what the `count` strategy means, e.g. in a
       Javadoc comment on `COUNT_STRATEGY`?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,437 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import 
org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the function of {@link KMeans}, supporting to train a 
K-Means model
+ * continuously according to an unbounded stream of train data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, 
generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current 
batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between 
the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of 
points assigned to
+ * them. The weight of the original centroids is also the number of points, 
but additionally
+ * multiplying with the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated 
thus far. If decay

Review comment:
       nit: "decay factor" -> "the decay factor" (e.g. "If the decay factor is 1", "If the decay factor is 0").

##########
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/common/distance/EuclideanDistanceMeasure.java
##########
@@ -34,6 +35,7 @@ public static EuclideanDistanceMeasure getInstance() {
 
     @Override
     public double distance(Vector v1, Vector v2) {
+        Preconditions.checkArgument(v1.size() == v2.size());
         double squaredDistance = 0.0;
 
         for (int i = 0; i < v1.size(); i++) {

Review comment:
       nit: do you think using BLAS here would be more efficient? It is okay to leave it as is, since it is not part of this PR.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,437 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import 
org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the function of {@link KMeans}, supporting to train a 
K-Means model
+ * continuously according to an unbounded stream of train data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, 
generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current 
batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between 
the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of 
points assigned to
+ * them. The weight of the original centroids is also the number of points, 
but additionally
+ * multiplying with the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated 
thus far. If decay
+ * factor is 1, all batches are weighted equally. If decay factor is 0, new 
centroids are determined
+ * entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, 
OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) 
tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new 
FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> finalModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), 
DataStreamList.of(points), body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelData);
+        OnlineKMeansModel model = new 
OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND bounded model data table (if exists) to the 
given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String 
path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);

Review comment:
       nit: would `onlineKMeans` be a better name here?

##########
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java
##########
@@ -0,0 +1,476 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.RestOptions;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.clustering.kmeans.KMeansModel;
+import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeans;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.util.InMemorySinkFunction;
+import org.apache.flink.ml.util.InMemorySourceFunction;
+import org.apache.flink.runtime.minicluster.MiniCluster;
+import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
+import org.apache.flink.runtime.testutils.InMemoryReporter;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.TestLogger;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static 
org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction;
+
+/** Tests {@link OnlineKMeans} and {@link OnlineKMeansModel}. */
+public class OnlineKMeansTest extends TestLogger {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final DenseVector[] trainData1 =

Review comment:
       One quick question: should we make it `static` or not?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,437 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import 
org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the function of {@link KMeans}, supporting to train a 
K-Means model
+ * continuously according to an unbounded stream of train data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, 
generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current 
batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between 
the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of 
points assigned to
+ * them. The weight of the original centroids is also the number of points, 
but additionally
+ * multiplying with the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated 
thus far. If decay
+ * factor is 1, all batches are weighted equally. If decay factor is 0, new 
centroids are determined
+ * entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, 
OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) 
tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new 
FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> finalModelData =

Review comment:
       nit: could `finalModelData` be renamed to `onlineModelData`? It is not really the final model data, since it keeps being updated as new training data arrives.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,437 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import 
org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the function of {@link KMeans}, supporting to train a 
K-Means model
+ * continuously according to an unbounded stream of train data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, 
generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current 
batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between 
the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of 
points assigned to
+ * them. The weight of the original centroids is also the number of points, 
but additionally
+ * multiplying with the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated 
thus far. If decay
+ * factor is 1, all batches are weighted equally. If decay factor is 0, new 
centroids are determined
+ * entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, 
OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) 
tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new 
FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> finalModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), 
DataStreamList.of(points), body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelData);
+        OnlineKMeansModel model = new 
OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND bounded model data table (if exists) to the 
given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String 
path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new 
KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = 
tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int 
batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<KMeansModelData> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(KMeansModelData.class),
+                                    new ModelDataLocalUpdater(distanceMeasure, 
decayFactor))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData), 
DataStreamList.of(modelData));
+        }
+    }
+
+    private static class ModelDataGlobalReducer implements 
ReduceFunction<KMeansModelData> {
+        @Override
+        public KMeansModelData reduce(KMeansModelData modelData, 
KMeansModelData newModelData) {
+            DenseVector weights = modelData.weights;
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector newWeights = newModelData.weights;
+            DenseVector[] newCentroids = newModelData.centroids;
+
+            int k = newCentroids.length;
+            int dim = newCentroids[0].size();
+
+            for (int i = 0; i < k; i++) {
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * 
newWeights.values[i])
+                                    / Math.max(weights.values[i] + 
newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            return new KMeansModelData(centroids, weights);
+        }
+    }
+
+    private static class ModelDataLocalUpdater extends 
AbstractStreamOperator<KMeansModelData>
+            implements TwoInputStreamOperator<DenseVector[], KMeansModelData, 
KMeansModelData> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private ListState<DenseVector[]> localBatchState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataLocalUpdater(DistanceMeasure distanceMeasure, double 
decayFactor) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws 
Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    
ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new 
ListStateDescriptor<>("localBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", 
KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) 
throws Exception {
+            localBatchState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<KMeansModelData> 
modelDataRecord)
+                throws Exception {
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        private void alignAndComputeModelData() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            KMeansModelData modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, 
"modelData")
+                            .orElseThrow((Supplier<Exception>) 
NullPointerException::new);
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector weights = modelData.weights;
+            modelDataState.clear();
+
+            List<DenseVector[]> pointsList = 
IteratorUtils.toList(localBatchState.get().iterator());
+            DenseVector[] points = pointsList.remove(0);
+            localBatchState.update(pointsList);
+
+            int dim = centroids[0].size();
+            int k = centroids.length;
+            int parallelism = 
getRuntimeContext().getNumberOfParallelSubtasks();
+
+            // Computes new centroids.
+            DenseVector[] sums = new DenseVector[k];
+            int[] counts = new int[k];
+
+            for (int i = 0; i < k; i++) {
+                sums[i] = new DenseVector(dim);
+                counts[i] = 0;
+            }
+            for (DenseVector point : points) {
+                int closestCentroidId =
+                        KMeans.findClosestCentroidId(centroids, point, 
distanceMeasure);
+                counts[closestCentroidId]++;
+                for (int j = 0; j < dim; j++) {
+                    sums[closestCentroidId].values[j] += point.values[j];
+                }
+            }
+
+            // Considers weight and decay factor when updating centroids.
+            BLAS.scal(decayFactor / parallelism, weights);
+            for (int i = 0; i < k; i++) {
+                if (counts[i] == 0) {
+                    continue;
+                }
+
+                DenseVector centroid = centroids[i];
+                weights.values[i] = weights.values[i] + counts[i];
+                double lambda = counts[i] / weights.values[i];
+
+                BLAS.scal(1.0 - lambda, centroid);
+                BLAS.axpy(lambda / counts[i], sums[i], centroid);
+            }
+
+            output.collect(new StreamRecord<>(new KMeansModelData(centroids, 
weights)));
+        }
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, 
DenseVector> {
+        private final String featuresCol;
+
+        private FeaturesExtractor(String featuresCol) {
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public DenseVector map(Row row) throws Exception {
+            return (DenseVector) row.getField(featuresCol);
+        }
+    }
+
+    // An operator that splits a global batch into evenly-sized local batches, 
and distributes them
+    // to downstream operator.
+    private static class GlobalBatchSplitter
+            implements FlatMapFunction<DenseVector[], DenseVector[]> {
+        private final int downStreamParallelism;
+
+        private GlobalBatchSplitter(int downStreamParallelism) {
+            this.downStreamParallelism = downStreamParallelism;
+        }
+
+        @Override
+        public void flatMap(DenseVector[] values, Collector<DenseVector[]> 
collector) {
+            // Calculate the batch sizes to be distributed on each subtask.
+            List<Integer> sizes = new ArrayList<>();

Review comment:
       Would the following code be more readable?
   
   ```
   int div = values.length / downStreamParallelism;
   int mod = values.length % downStreamParallelism;
   int offset = 0;
   for (int i = 0; i < downStreamParallelism; i++) {
           int size = i >= mod ? div : div + 1;
           collector.collect(Arrays.copyOfRange(values, offset, offset + size));
           offset += size;
   }
   ```

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,437 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import 
org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the function of {@link KMeans}, supporting to train a 
K-Means model
+ * continuously according to an unbounded stream of train data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, 
generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current 
batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between 
the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of 
points assigned to
+ * them. The weight of the original centroids is also the number of points, 
but additionally
+ * multiplying with the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated 
thus far. If decay
+ * factor is 1, all batches are weighted equally. If decay factor is 0, new 
centroids are determined
+ * entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, 
OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) 
tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new 
FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> finalModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), 
DataStreamList.of(points), body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelData);
+        OnlineKMeansModel model = new 
OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND bounded model data table (if exists) to the 
given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String 
path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new 
KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = 
tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int 
batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<KMeansModelData> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(KMeansModelData.class),
+                                    new ModelDataLocalUpdater(distanceMeasure, 
decayFactor))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData), 
DataStreamList.of(modelData));
+        }
+    }
+
+    private static class ModelDataGlobalReducer implements 
ReduceFunction<KMeansModelData> {
+        @Override
+        public KMeansModelData reduce(KMeansModelData modelData, 
KMeansModelData newModelData) {
+            DenseVector weights = modelData.weights;
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector newWeights = newModelData.weights;
+            DenseVector[] newCentroids = newModelData.centroids;
+
+            int k = newCentroids.length;
+            int dim = newCentroids[0].size();
+
+            for (int i = 0; i < k; i++) {
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * 
newWeights.values[i])
+                                    / Math.max(weights.values[i] + 
newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            return new KMeansModelData(centroids, weights);
+        }
+    }
+
+    private static class ModelDataLocalUpdater extends 
AbstractStreamOperator<KMeansModelData>
+            implements TwoInputStreamOperator<DenseVector[], KMeansModelData, 
KMeansModelData> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private ListState<DenseVector[]> localBatchState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataLocalUpdater(DistanceMeasure distanceMeasure, double 
decayFactor) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws 
Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    
ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new 
ListStateDescriptor<>("localBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", 
KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) 
throws Exception {
+            localBatchState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<KMeansModelData> 
modelDataRecord)
+                throws Exception {
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        private void alignAndComputeModelData() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            KMeansModelData modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, 
"modelData")
+                            .orElseThrow((Supplier<Exception>) 
NullPointerException::new);
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector weights = modelData.weights;
+            modelDataState.clear();
+
+            List<DenseVector[]> pointsList = 
IteratorUtils.toList(localBatchState.get().iterator());
+            DenseVector[] points = pointsList.remove(0);
+            localBatchState.update(pointsList);
+
+            int dim = centroids[0].size();
+            int k = centroids.length;
+            int parallelism = 
getRuntimeContext().getNumberOfParallelSubtasks();
+
+            // Computes new centroids.
+            DenseVector[] sums = new DenseVector[k];
+            int[] counts = new int[k];
+
+            for (int i = 0; i < k; i++) {
+                sums[i] = new DenseVector(dim);
+                counts[i] = 0;

Review comment:
       nit: this line could be removed, since `new int[k]` already initializes every element to 0.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java
##########
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.co.CoProcessFunction;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * OnlineKMeansModel can be regarded as an advanced {@link KMeansModel} 
operator which can update
+ * model data in a streaming format, using the model data provided by {@link 
OnlineKMeans}.
+ */
+public class OnlineKMeansModel
+        implements Model<OnlineKMeansModel>, 
KMeansModelParams<OnlineKMeansModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public OnlineKMeansModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel setModelData(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = 
TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), 
Types.INT),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), 
getPredictionCol()));
+
+        DataStream<Row> predictionResult =
+                KMeansModelData.getModelDataStream(modelDataTable)
+                        .broadcast()
+                        .connect(tEnv.toDataStream(inputs[0]))
+                        .process(
+                                new PredictLabelFunction(
+                                        getFeaturesCol(),
+                                        
DistanceMeasure.getInstance(getDistanceMeasure())),
+                                outputTypeInfo);
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility function used for prediction. */
+    private static class PredictLabelFunction extends 
CoProcessFunction<KMeansModelData, Row, Row> {
+        private final String featuresCol;
+
+        private final DistanceMeasure distanceMeasure;
+
+        private DenseVector[] centroids;
+
+        // TODO: replace this with a complete solution of reading first model 
data from unbounded
+        // model data stream before processing the first predict data.
+        private final List<Row> bufferedPoints = new ArrayList<>();

Review comment:
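       Question: is `bufferedPoints` included in the checkpoint when a snapshot is taken?
It is a plain in-memory `ArrayList` field rather than operator state, so points buffered
before the first model data arrives would seem to be lost on failover.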

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,437 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import 
org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the function of {@link KMeans}, supporting to train a 
K-Means model
+ * continuously according to an unbounded stream of train data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, 
generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current 
batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between 
the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of 
points assigned to
+ * them. The weight of the original centroids is also the number of points, 
but additionally
+ * multiplying with the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated 
thus far. If decay
+ * factor is 1, all batches are weighted equally. If decay factor is 0, new 
centroids are determined
+ * entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, 
OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) 
tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new 
FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> finalModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), 
DataStreamList.of(points), body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelData);
+        OnlineKMeansModel model = new 
OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND bounded model data table (if exists) to the 
given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String 
path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new 
KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = 
tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int 
batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<KMeansModelData> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(KMeansModelData.class),
+                                    new ModelDataLocalUpdater(distanceMeasure, 
decayFactor))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData), 
DataStreamList.of(modelData));
+        }
+    }
+
+    private static class ModelDataGlobalReducer implements 
ReduceFunction<KMeansModelData> {
+        @Override
+        public KMeansModelData reduce(KMeansModelData modelData, 
KMeansModelData newModelData) {
+            DenseVector weights = modelData.weights;
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector newWeights = newModelData.weights;
+            DenseVector[] newCentroids = newModelData.centroids;
+
+            int k = newCentroids.length;
+            int dim = newCentroids[0].size();
+
+            for (int i = 0; i < k; i++) {
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * 
newWeights.values[i])
+                                    / Math.max(weights.values[i] + 
newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            return new KMeansModelData(centroids, weights);
+        }
+    }
+
+    private static class ModelDataLocalUpdater extends 
AbstractStreamOperator<KMeansModelData>
+            implements TwoInputStreamOperator<DenseVector[], KMeansModelData, 
KMeansModelData> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private ListState<DenseVector[]> localBatchState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataLocalUpdater(DistanceMeasure distanceMeasure, double 
decayFactor) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws 
Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    
ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new 
ListStateDescriptor<>("localBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", 
KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) 
throws Exception {
+            localBatchState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<KMeansModelData> 
modelDataRecord)
+                throws Exception {
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        private void alignAndComputeModelData() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            KMeansModelData modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, 
"modelData")
+                            .orElseThrow((Supplier<Exception>) 
NullPointerException::new);
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector weights = modelData.weights;
+            modelDataState.clear();
+
+            List<DenseVector[]> pointsList = 
IteratorUtils.toList(localBatchState.get().iterator());
+            DenseVector[] points = pointsList.remove(0);
+            localBatchState.update(pointsList);
+
+            int dim = centroids[0].size();
+            int k = centroids.length;
+            int parallelism = 
getRuntimeContext().getNumberOfParallelSubtasks();
+
+            // Computes new centroids.
+            DenseVector[] sums = new DenseVector[k];
+            int[] counts = new int[k];
+
+            for (int i = 0; i < k; i++) {
+                sums[i] = new DenseVector(dim);
+                counts[i] = 0;
+            }
+            for (DenseVector point : points) {
+                int closestCentroidId =
+                        KMeans.findClosestCentroidId(centroids, point, 
distanceMeasure);
+                counts[closestCentroidId]++;
+                for (int j = 0; j < dim; j++) {
+                    sums[closestCentroidId].values[j] += point.values[j];

Review comment:
       nit: this inner loop could be a BLAS operation, e.g. `BLAS.axpy(1.0, point, sums[closestCentroidId])`.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,437 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import 
org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the function of {@link KMeans}, supporting to train a 
K-Means model
+ * continuously according to an unbounded stream of train data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, 
generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current 
batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between 
the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of 
points assigned to
+ * them. The weight of the original centroids is also the number of points, 
but additionally
+ * multiplying with the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated 
thus far. If decay
+ * factor is 1, all batches are weighted equally. If decay factor is 0, new 
centroids are determined
+ * entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, 
OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) 
tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new 
FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> finalModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), 
DataStreamList.of(points), body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelData);
+        OnlineKMeansModel model = new 
OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND bounded model data table (if exists) to the 
given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String 
path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new 
KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = 
tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int 
batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<KMeansModelData> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(KMeansModelData.class),
+                                    new ModelDataLocalUpdater(distanceMeasure, 
decayFactor))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData), 
DataStreamList.of(modelData));
+        }
+    }
+
+    private static class ModelDataGlobalReducer implements 
ReduceFunction<KMeansModelData> {
+        @Override
+        public KMeansModelData reduce(KMeansModelData modelData, 
KMeansModelData newModelData) {
+            DenseVector weights = modelData.weights;
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector newWeights = newModelData.weights;
+            DenseVector[] newCentroids = newModelData.centroids;
+
+            int k = newCentroids.length;
+            int dim = newCentroids[0].size();
+
+            for (int i = 0; i < k; i++) {
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * 
newWeights.values[i])
+                                    / Math.max(weights.values[i] + 
newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            return new KMeansModelData(centroids, weights);
+        }
+    }
+
+    private static class ModelDataLocalUpdater extends 
AbstractStreamOperator<KMeansModelData>
+            implements TwoInputStreamOperator<DenseVector[], KMeansModelData, 
KMeansModelData> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private ListState<DenseVector[]> localBatchState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataLocalUpdater(DistanceMeasure distanceMeasure, double 
decayFactor) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws 
Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    
ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new 
ListStateDescriptor<>("localBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", 
KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) 
throws Exception {
+            localBatchState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<KMeansModelData> 
modelDataRecord)
+                throws Exception {
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        private void alignAndComputeModelData() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            KMeansModelData modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, 
"modelData")
+                            .orElseThrow((Supplier<Exception>) 
NullPointerException::new);
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector weights = modelData.weights;
+            modelDataState.clear();
+
+            List<DenseVector[]> pointsList = 
IteratorUtils.toList(localBatchState.get().iterator());
+            DenseVector[] points = pointsList.remove(0);
+            localBatchState.update(pointsList);
+
+            int dim = centroids[0].size();
+            int k = centroids.length;
+            int parallelism = 
getRuntimeContext().getNumberOfParallelSubtasks();
+
+            // Computes new centroids.
+            DenseVector[] sums = new DenseVector[k];
+            int[] counts = new int[k];
+
+            for (int i = 0; i < k; i++) {
+                sums[i] = new DenseVector(dim);
+                counts[i] = 0;
+            }
+            for (DenseVector point : points) {
+                int closestCentroidId =
+                        KMeans.findClosestCentroidId(centroids, point, 
distanceMeasure);
+                counts[closestCentroidId]++;
+                for (int j = 0; j < dim; j++) {
+                    sums[closestCentroidId].values[j] += point.values[j];
+                }
+            }
+
+            // Considers weight and decay factor when updating centroids.
+            BLAS.scal(decayFactor / parallelism, weights);
+            for (int i = 0; i < k; i++) {
+                if (counts[i] == 0) {
+                    continue;
+                }
+
+                DenseVector centroid = centroids[i];
+                weights.values[i] = weights.values[i] + counts[i];
+                double lambda = counts[i] / weights.values[i];
+
+                BLAS.scal(1.0 - lambda, centroid);
+                BLAS.axpy(lambda / counts[i], sums[i], centroid);
+            }
+
+            output.collect(new StreamRecord<>(new KMeansModelData(centroids, 
weights)));
+        }
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, 
DenseVector> {
+        private final String featuresCol;
+
+        private FeaturesExtractor(String featuresCol) {
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public DenseVector map(Row row) throws Exception {
+            return (DenseVector) row.getField(featuresCol);
+        }
+    }
+
+    // An operator that splits a global batch into evenly-sized local batches, 
and distributes them
+    // to downstream operator.
+    private static class GlobalBatchSplitter
+            implements FlatMapFunction<DenseVector[], DenseVector[]> {
+        private final int downStreamParallelism;
+
+        private GlobalBatchSplitter(int downStreamParallelism) {
+            this.downStreamParallelism = downStreamParallelism;
+        }
+
+        @Override
+        public void flatMap(DenseVector[] values, Collector<DenseVector[]> 
collector) {
+            // Calculate the batch sizes to be distributed on each subtask.
+            List<Integer> sizes = new ArrayList<>();
+            for (int i = 0; i < downStreamParallelism; i++) {
+                int start = i * values.length / downStreamParallelism;
+                int end = (i + 1) * values.length / downStreamParallelism;
+                sizes.add(end - start);
+            }
+
+            int offset = 0;
+            for (Integer size : sizes) {
+                collector.collect(Arrays.copyOfRange(values, offset, offset + 
size));
+                offset += size;
+            }
+        }
+    }
+
+    private static class GlobalBatchCreator
+            implements AllWindowFunction<DenseVector, DenseVector[], 
GlobalWindow> {
+        @Override
+        public void apply(
+                GlobalWindow timeWindow,
+                Iterable<DenseVector> iterable,
+                Collector<DenseVector[]> collector) {
+            List<DenseVector> points = 
IteratorUtils.toList(iterable.iterator());
+            collector.collect(points.toArray(new DenseVector[0]));
+        }
+    }
+
+    /**
+     * Sets the initial model data of the online training process with the 
provided model data
+     * table.
+     *
+     * <p>This would override the effect of previously invoked {@link 
#setInitialModelData(Table)}
+     * or {@link #setRandomCentroids(StreamTableEnvironment, int, int, 
double)}.
+     */
+    public OnlineKMeans setInitialModelData(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        return this;
+    }
+
+    /**
+     * Sets the initial model data of the online training process with 
randomly created centroids.
+     *
+     * <p>This would override the effect of previously invoked {@link 
#setInitialModelData(Table)}
+     * or {@link #setRandomCentroids(StreamTableEnvironment, int, int, 
double)}.
+     *
+     * @param tEnv The stream table environment to create the centroids in.
+     * @param dim The dimension of the centroids to create.
+     * @param k The number of centroids to create.
+     * @param weight The weight of the centroids to create.
+     */
+    public OnlineKMeans setRandomCentroids(

Review comment:
       If users do not provide an initial model, could we follow `KMeans` and 
randomly initialize the model data? 
   
   Exposing this method to end users seems a bit unnecessary to me.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,437 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import 
org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.function.Supplier;
+
+/**
+ * OnlineKMeans extends the function of {@link KMeans}, supporting to train a 
K-Means model
+ * continuously according to an unbounded stream of train data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, 
generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current 
batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between 
the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of 
points assigned to
+ * them. The weight of the original centroids is also the number of points, 
but additionally
+ * multiplying with the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated 
thus far. If decay
+ * factor is 1, all batches are weighted equally. If decay factor is 0, new 
centroids are determined
+ * entirely by recent data. Lower values correspond to more forgetting.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, 
OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) 
tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new 
FeaturesExtractor(getFeaturesCol()));
+
+        DataStream<KMeansModelData> initModelData =
+                KMeansModelData.getModelDataStream(initModelDataTable);
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new OnlineKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getGlobalBatchSize());
+
+        DataStream<KMeansModelData> finalModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), 
DataStreamList.of(points), body)
+                        .get(0);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelData);
+        OnlineKMeansModel model = new 
OnlineKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Saves the metadata AND bounded model data table (if exists) to the 
given path. */
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static OnlineKMeans load(StreamExecutionEnvironment env, String 
path)
+            throws IOException {
+        OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        String initModelDataPath = ReadWriteUtils.getDataPath(path);
+        if (Files.exists(Paths.get(initModelDataPath))) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new 
KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = 
tEnv.fromDataStream(initModelDataStream);
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class OnlineKMeansIterationBody implements IterationBody {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+
+        public OnlineKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int 
batchSize) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<KMeansModelData> modelData = variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+
+            DataStream<KMeansModelData> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(KMeansModelData.class),
+                                    new ModelDataLocalUpdater(distanceMeasure, 
decayFactor))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new ModelDataGlobalReducer());
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData), 
DataStreamList.of(modelData));
+        }
+    }
+
+    private static class ModelDataGlobalReducer implements 
ReduceFunction<KMeansModelData> {
+        @Override
+        public KMeansModelData reduce(KMeansModelData modelData, 
KMeansModelData newModelData) {
+            DenseVector weights = modelData.weights;
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector newWeights = newModelData.weights;
+            DenseVector[] newCentroids = newModelData.centroids;
+
+            int k = newCentroids.length;
+            int dim = newCentroids[0].size();
+
+            for (int i = 0; i < k; i++) {
+                for (int j = 0; j < dim; j++) {
+                    centroids[i].values[j] =
+                            (centroids[i].values[j] * weights.values[i]
+                                            + newCentroids[i].values[j] * 
newWeights.values[i])
+                                    / Math.max(weights.values[i] + 
newWeights.values[i], 1e-16);
+                }
+                weights.values[i] += newWeights.values[i];
+            }
+
+            return new KMeansModelData(centroids, weights);
+        }
+    }
+
+    private static class ModelDataLocalUpdater extends 
AbstractStreamOperator<KMeansModelData>
+            implements TwoInputStreamOperator<DenseVector[], KMeansModelData, 
KMeansModelData> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private ListState<DenseVector[]> localBatchState;
+        private ListState<KMeansModelData> modelDataState;
+
+        private ModelDataLocalUpdater(DistanceMeasure distanceMeasure, double 
decayFactor) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws 
Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    
ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new 
ListStateDescriptor<>("localBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", 
KMeansModelData.class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> pointsRecord) 
throws Exception {
+            localBatchState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<KMeansModelData> 
modelDataRecord)
+                throws Exception {
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        private void alignAndComputeModelData() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            KMeansModelData modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, 
"modelData")
+                            .orElseThrow((Supplier<Exception>) 
NullPointerException::new);
+            DenseVector[] centroids = modelData.centroids;
+            DenseVector weights = modelData.weights;
+            modelDataState.clear();
+
+            List<DenseVector[]> pointsList = 
IteratorUtils.toList(localBatchState.get().iterator());
+            DenseVector[] points = pointsList.remove(0);
+            localBatchState.update(pointsList);
+
+            int dim = centroids[0].size();
+            int k = centroids.length;
+            int parallelism = 
getRuntimeContext().getNumberOfParallelSubtasks();
+
+            // Computes new centroids.
+            DenseVector[] sums = new DenseVector[k];
+            int[] counts = new int[k];
+
+            for (int i = 0; i < k; i++) {
+                sums[i] = new DenseVector(dim);
+                counts[i] = 0;
+            }
+            for (DenseVector point : points) {
+                int closestCentroidId =
+                        KMeans.findClosestCentroidId(centroids, point, 
distanceMeasure);
+                counts[closestCentroidId]++;
+                for (int j = 0; j < dim; j++) {
+                    sums[closestCentroidId].values[j] += point.values[j];
+                }
+            }
+
+            // Considers weight and decay factor when updating centroids.
+            BLAS.scal(decayFactor / parallelism, weights);
+            for (int i = 0; i < k; i++) {
+                if (counts[i] == 0) {
+                    continue;
+                }
+
+                DenseVector centroid = centroids[i];
+                weights.values[i] = weights.values[i] + counts[i];
+                double lambda = counts[i] / weights.values[i];
+
+                BLAS.scal(1.0 - lambda, centroid);
+                BLAS.axpy(lambda / counts[i], sums[i], centroid);
+            }
+
+            output.collect(new StreamRecord<>(new KMeansModelData(centroids, 
weights)));
+        }
+    }
+
+    private static class FeaturesExtractor implements MapFunction<Row, 
DenseVector> {
+        private final String featuresCol;
+
+        private FeaturesExtractor(String featuresCol) {
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public DenseVector map(Row row) throws Exception {
+            return (DenseVector) row.getField(featuresCol);
+        }
+    }
+
+    // An operator that splits a global batch into evenly-sized local batches, 
and distributes them
+    // to downstream operator.
+    private static class GlobalBatchSplitter
+            implements FlatMapFunction<DenseVector[], DenseVector[]> {
+        private final int downStreamParallelism;
+
+        private GlobalBatchSplitter(int downStreamParallelism) {
+            this.downStreamParallelism = downStreamParallelism;
+        }
+
+        @Override
+        public void flatMap(DenseVector[] values, Collector<DenseVector[]> 
collector) {
+            // Calculate the batch sizes to be distributed on each subtask.
+            List<Integer> sizes = new ArrayList<>();
+            for (int i = 0; i < downStreamParallelism; i++) {
+                int start = i * values.length / downStreamParallelism;
+                int end = (i + 1) * values.length / downStreamParallelism;
+                sizes.add(end - start);
+            }
+
+            int offset = 0;
+            for (Integer size : sizes) {
+                collector.collect(Arrays.copyOfRange(values, offset, offset + 
size));
+                offset += size;
+            }
+        }
+    }
+
+    private static class GlobalBatchCreator
+            implements AllWindowFunction<DenseVector, DenseVector[], 
GlobalWindow> {
+        @Override
+        public void apply(
+                GlobalWindow timeWindow,
+                Iterable<DenseVector> iterable,
+                Collector<DenseVector[]> collector) {
+            List<DenseVector> points = 
IteratorUtils.toList(iterable.iterator());
+            collector.collect(points.toArray(new DenseVector[0]));
+        }
+    }
+
+    /**
+     * Sets the initial model data of the online training process with the 
provided model data
+     * table.
+     *
+     * <p>This would override the effect of previously invoked {@link 
#setInitialModelData(Table)}
+     * or {@link #setRandomCentroids(StreamTableEnvironment, int, int, 
double)}.
+     */
+    public OnlineKMeans setInitialModelData(Table initModelDataTable) {
+        this.initModelDataTable = initModelDataTable;
+        return this;
+    }
+
+    /**
+     * Sets the initial model data of the online training process with 
randomly created centroids.
+     *
+     * <p>This would override the effect of previously invoked {@link 
#setInitialModelData(Table)}
+     * or {@link #setRandomCentroids(StreamTableEnvironment, int, int, 
double)}.
+     *
+     * @param tEnv The stream table environment to create the centroids in.
+     * @param dim The dimension of the centroids to create.
+     * @param k The number of centroids to create.
+     * @param weight The weight of the centroids to create.
+     */
+    public OnlineKMeans setRandomCentroids(
+            StreamTableEnvironment tEnv, int dim, int k, double weight) {

Review comment:
       +1

##########
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java
##########
@@ -0,0 +1,476 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.RestOptions;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.clustering.kmeans.KMeansModel;
+import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeans;
+import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.util.InMemorySinkFunction;
+import org.apache.flink.ml.util.InMemorySourceFunction;
+import org.apache.flink.runtime.minicluster.MiniCluster;
+import org.apache.flink.runtime.minicluster.MiniClusterConfiguration;
+import org.apache.flink.runtime.testutils.InMemoryReporter;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.TestLogger;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static 
org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction;
+
+/** Tests {@link OnlineKMeans} and {@link OnlineKMeansModel}. */
+public class OnlineKMeansTest extends TestLogger {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final DenseVector[] trainData1 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 0.0),
+                Vectors.dense(10.0, 0.3),
+                Vectors.dense(10.3, 0.0),
+                Vectors.dense(-10.0, 0.0),
+                Vectors.dense(-10.0, 0.6),
+                Vectors.dense(-10.6, 0.0)
+            };
+    private static final DenseVector[] trainData2 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 100.0),
+                Vectors.dense(10.0, 100.3),
+                Vectors.dense(10.3, 100.0),
+                Vectors.dense(-10.0, -100.0),
+                Vectors.dense(-10.0, -100.6),
+                Vectors.dense(-10.6, -100.0)
+            };
+    private static final DenseVector[] predictData =
+            new DenseVector[] {
+                Vectors.dense(10.0, 10.0),
+                Vectors.dense(10.3, 10.0),
+                Vectors.dense(10.0, 10.3),
+                Vectors.dense(-10.0, 10.0),
+                Vectors.dense(-10.3, 10.0),
+                Vectors.dense(-10.0, 10.3)
+            };
+    private static final List<Set<DenseVector>> expectedGroups1 =
+            Arrays.asList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3))),
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+    private static final List<Set<DenseVector>> expectedGroups2 =
+            Collections.singletonList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3),
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+
+    private static final int defaultParallelism = 4;
+    private static final int numTaskManagers = 2;
+    private static final int numSlotsPerTaskManager = 2;
+
+    private int currentModelDataVersion;
+
+    private InMemorySourceFunction<DenseVector> trainSource;
+    private InMemorySourceFunction<DenseVector> predictSource;
+    private InMemorySinkFunction<Row> outputSink;
+    private InMemorySinkFunction<KMeansModelData> modelDataSink;
+
+    // TODO: creates static mini cluster once for whole test class after 
dependency upgrades to
+    // Flink 1.15.
+    private InMemoryReporter reporter;
+    private MiniCluster miniCluster;
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+
+    private Table offlineTrainTable;
+    private Table onlineTrainTable;
+    private Table onlinePredictTable;
+
+    @Before
+    public void before() throws Exception {
+        currentModelDataVersion = 0;
+
+        trainSource = new InMemorySourceFunction<>();
+        predictSource = new InMemorySourceFunction<>();
+        outputSink = new InMemorySinkFunction<>();
+        modelDataSink = new InMemorySinkFunction<>();
+
+        Configuration config = new Configuration();
+        config.set(RestOptions.BIND_PORT, "18081-19091");
+        
config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, 
true);
+        reporter = InMemoryReporter.createWithRetainedMetrics();
+        reporter.addToConfiguration(config);
+
+        miniCluster =
+                new MiniCluster(
+                        new MiniClusterConfiguration.Builder()
+                                .setConfiguration(config)
+                                .setNumTaskManagers(numTaskManagers)
+                                
.setNumSlotsPerTaskManager(numSlotsPerTaskManager)
+                                .build());
+        miniCluster.start();
+
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(defaultParallelism);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        Schema schema = Schema.newBuilder().column("f0", 
DataTypes.of(DenseVector.class)).build();

Review comment:
       nit: `schema` seems unnecessary here. Could we simply replace `f0` 
with `features` to simplify the code?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to