yunfengzhou-hub commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r829707954



##########
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/StreamingKMeansTest.java
##########
@@ -0,0 +1,527 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.execution.JobClient;
+import org.apache.flink.ml.clustering.kmeans.KMeans;
+import org.apache.flink.ml.clustering.kmeans.KMeansModel;
+import org.apache.flink.ml.clustering.kmeans.KMeansModelData;
+import org.apache.flink.ml.clustering.kmeans.StreamingKMeans;
+import org.apache.flink.ml.clustering.kmeans.StreamingKMeansModel;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.util.MockKVStore;
+import org.apache.flink.ml.util.MockMessageQueues;
+import org.apache.flink.ml.util.MockSinkFunction;
+import org.apache.flink.ml.util.MockSourceFunction;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.ml.util.TestMetricReporter;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.CollectionUtils;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static 
org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction;
+
+/** Tests {@link StreamingKMeans} and {@link StreamingKMeansModel}. */
+public class StreamingKMeansTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private static final DenseVector[] trainData1 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 0.0),
+                Vectors.dense(10.0, 0.3),
+                Vectors.dense(10.3, 0.0),
+                Vectors.dense(-10.0, 0.0),
+                Vectors.dense(-10.0, 0.6),
+                Vectors.dense(-10.6, 0.0)
+            };
+    private static final DenseVector[] trainData2 =
+            new DenseVector[] {
+                Vectors.dense(10.0, 100.0),
+                Vectors.dense(10.0, 100.3),
+                Vectors.dense(10.3, 100.0),
+                Vectors.dense(-10.0, -100.0),
+                Vectors.dense(-10.0, -100.6),
+                Vectors.dense(-10.6, -100.0)
+            };
+    private static final DenseVector[] predictData =
+            new DenseVector[] {
+                Vectors.dense(10.0, 10.0),
+                Vectors.dense(10.3, 10.0),
+                Vectors.dense(10.0, 10.3),
+                Vectors.dense(-10.0, 10.0),
+                Vectors.dense(-10.3, 10.0),
+                Vectors.dense(-10.0, 10.3)
+            };
+    private static final List<Set<DenseVector>> expectedGroups1 =
+            Arrays.asList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3))),
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+    private static final List<Set<DenseVector>> expectedGroups2 =
+            Collections.singletonList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(10.0, 10.0),
+                                    Vectors.dense(10.3, 10.0),
+                                    Vectors.dense(10.0, 10.3),
+                                    Vectors.dense(-10.0, 10.0),
+                                    Vectors.dense(-10.3, 10.0),
+                                    Vectors.dense(-10.0, 10.3))));
+
+    private String trainId;
+    private String predictId;
+    private String outputId;
+    private String modelDataId;
+    private String metricReporterPrefix;
+    private String modelDataVersionGaugeKey;
+    private String currentModelDataVersion;
+    private List<String> queueIds;
+    private List<String> kvStoreKeys;
+    private List<JobClient> clients;
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table offlineTrainTable;
+    private Table trainTable;
+    private Table predictTable;
+
+    @Before
+    public void before() {
+        metricReporterPrefix = MockKVStore.createNonDuplicatePrefix();
+        kvStoreKeys = new ArrayList<>();
+        kvStoreKeys.add(metricReporterPrefix);
+
+        modelDataVersionGaugeKey =
+                TestMetricReporter.getKey(metricReporterPrefix, 
"modelDataVersion");
+
+        currentModelDataVersion = "0";
+
+        trainId = MockMessageQueues.createMessageQueue();
+        predictId = MockMessageQueues.createMessageQueue();
+        outputId = MockMessageQueues.createMessageQueue();
+        modelDataId = MockMessageQueues.createMessageQueue();
+        queueIds = new ArrayList<>();
+        queueIds.addAll(Arrays.asList(trainId, predictId, outputId, 
modelDataId));
+
+        clients = new ArrayList<>();
+
+        Configuration config = new Configuration();
+        
config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, 
true);
+        config.setString("metrics.reporters", "test_reporter");
+        config.setString(
+                "metrics.reporter.test_reporter.class", 
TestMetricReporter.class.getName());
+        config.setString("metrics.reporter.test_reporter.interval", "100 
MILLISECONDS");
+        config.setString("metrics.reporter.test_reporter.prefix", 
metricReporterPrefix);
+
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        Schema schema = Schema.newBuilder().column("f0", 
DataTypes.of(DenseVector.class)).build();
+
+        offlineTrainTable =
+                tEnv.fromDataStream(env.fromElements(trainData1), 
schema).as("features");
+        trainTable =
+                tEnv.fromDataStream(
+                                env.addSource(
+                                        new MockSourceFunction<>(trainId),
+                                        DenseVectorTypeInfo.INSTANCE),
+                                schema)
+                        .as("features");
+        predictTable =
+                tEnv.fromDataStream(
+                                env.addSource(
+                                        new MockSourceFunction<>(predictId),
+                                        DenseVectorTypeInfo.INSTANCE),
+                                schema)
+                        .as("features");
+    }
+
+    @After
+    public void after() {
+        for (JobClient client : clients) {
+            try {
+                client.cancel();
+            } catch (IllegalStateException e) {
+                if (!e.getMessage()
+                        .equals("MiniCluster is not yet running or has already 
been shut down.")) {
+                    throw e;
+                }
+            }
+        }
+        clients.clear();
+
+        for (String queueId : queueIds) {
+            MockMessageQueues.deleteBlockingQueue(queueId);
+        }
+        queueIds.clear();
+
+        for (String key : kvStoreKeys) {
+            MockKVStore.remove(key);
+        }
+        kvStoreKeys.clear();
+    }
+
+    /** Adds sinks for StreamingKMeansModel's transform output and model data. 
*/
+    private void configModelSink(StreamingKMeansModel streamingModel) {
+        Table outputTable = streamingModel.transform(predictTable)[0];
+        DataStream<Row> output = tEnv.toDataStream(outputTable);
+        output.addSink(new MockSinkFunction<>(outputId));
+
+        Table modelDataTable = streamingModel.getModelData()[0];
+        DataStream<KMeansModelData> modelDataStream =
+                KMeansModelData.getModelDataStream(modelDataTable);
+        modelDataStream.addSink(new MockSinkFunction<>(modelDataId));
+    }
+
+    /** Blocks the thread until Model has set up init model data. */
+    private void waitInitModelDataSetup() {
+        while (!MockKVStore.containsKey(modelDataVersionGaugeKey)) {
+            Thread.yield();
+        }
+        waitModelDataUpdate();
+    }
+
+    /** Blocks the thread until the Model has received the next 
model-data-update event. */
+    private void waitModelDataUpdate() {
+        do {
+            String tmpModelDataVersion = 
MockKVStore.get(modelDataVersionGaugeKey);
+            if (tmpModelDataVersion.equals(currentModelDataVersion)) {
+                Thread.yield();
+            } else {
+                currentModelDataVersion = tmpModelDataVersion;
+                break;
+            }
+        } while (true);
+    }
+
+    /**
+     * Inserts default predict data to the predict queue, fetches the 
prediction results, and
+     * asserts that the grouping result is as expected.
+     *
+     * @param expectedGroups A list containing sets of features, which is the 
expected group result
+     * @param featuresCol Name of the column in the table that contains the 
features
+     * @param predictionCol Name of the column in the table that contains the 
prediction result
+     */
+    private void predictAndAssert(
+            List<Set<DenseVector>> expectedGroups, String featuresCol, String 
predictionCol)
+            throws Exception {
+        MockMessageQueues.offerAll(predictId, StreamingKMeansTest.predictData);
+        List<Row> rawResult =
+                MockMessageQueues.poll(outputId, 
StreamingKMeansTest.predictData.length);
+        List<Set<DenseVector>> actualGroups =
+                groupFeaturesByPrediction(rawResult, featuresCol, 
predictionCol);
+        Assert.assertTrue(CollectionUtils.isEqualCollection(expectedGroups, 
actualGroups));
+    }
+
+    @Test
+    public void testParam() {
+        StreamingKMeans streamingKMeans = new StreamingKMeans();
+        Assert.assertEquals("features", streamingKMeans.getFeaturesCol());
+        Assert.assertEquals("prediction", streamingKMeans.getPredictionCol());
+        Assert.assertEquals(EuclideanDistanceMeasure.NAME, 
streamingKMeans.getDistanceMeasure());
+        Assert.assertEquals("random", streamingKMeans.getInitMode());
+        Assert.assertEquals(2, streamingKMeans.getK());
+        Assert.assertEquals(1, streamingKMeans.getDims());
+        Assert.assertEquals("count", streamingKMeans.getBatchStrategy());
+        Assert.assertEquals(1, streamingKMeans.getBatchSize());
+        Assert.assertEquals(0., streamingKMeans.getDecayFactor(), 1e-5);
+        Assert.assertEquals("random", streamingKMeans.getInitMode());
+        Assert.assertEquals(StreamingKMeans.class.getName().hashCode(), 
streamingKMeans.getSeed());
+
+        streamingKMeans
+                .setK(9)
+                .setFeaturesCol("test_feature")
+                .setPredictionCol("test_prediction")
+                .setK(3)
+                .setDims(5)
+                .setBatchSize(5)
+                .setDecayFactor(0.25)
+                .setInitMode("direct")
+                .setSeed(100);
+
+        Assert.assertEquals("test_feature", streamingKMeans.getFeaturesCol());
+        Assert.assertEquals("test_prediction", 
streamingKMeans.getPredictionCol());
+        Assert.assertEquals(3, streamingKMeans.getK());
+        Assert.assertEquals(5, streamingKMeans.getDims());
+        Assert.assertEquals("count", streamingKMeans.getBatchStrategy());
+        Assert.assertEquals(5, streamingKMeans.getBatchSize());
+        Assert.assertEquals(0.25, streamingKMeans.getDecayFactor(), 1e-5);
+        Assert.assertEquals("direct", streamingKMeans.getInitMode());
+        Assert.assertEquals(100, streamingKMeans.getSeed());
+    }
+
+    @Test
+    public void testFitAndPredict() throws Exception {
+        StreamingKMeans streamingKMeans =
+                new StreamingKMeans()
+                        .setInitMode("random")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setBatchSize(6)
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction");
+        StreamingKMeansModel streamingModel = streamingKMeans.fit(trainTable);
+        configModelSink(streamingModel);
+
+        clients.add(env.executeAsync());
+        waitInitModelDataSetup();
+
+        MockMessageQueues.offerAll(trainId, trainData1);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups1,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+
+        MockMessageQueues.offerAll(trainId, trainData2);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups2,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+    }
+
+    @Test
+    public void testInitWithOfflineKMeans() throws Exception {
+        KMeans offlineKMeans =
+                new 
KMeans().setFeaturesCol("features").setPredictionCol("prediction");
+        KMeansModel offlineModel = offlineKMeans.fit(offlineTrainTable);
+
+        StreamingKMeans streamingKMeans =
+                new StreamingKMeans(offlineModel.getModelData())
+                        .setInitMode("direct")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setBatchSize(6);
+        ReadWriteUtils.updateExistingParams(streamingKMeans, 
offlineKMeans.getParamMap());
+
+        StreamingKMeansModel streamingModel = streamingKMeans.fit(trainTable);
+        configModelSink(streamingModel);
+
+        clients.add(env.executeAsync());
+        waitInitModelDataSetup();
+        predictAndAssert(
+                expectedGroups1,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+
+        MockMessageQueues.offerAll(trainId, trainData2);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups2,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+    }
+
+    @Test
+    public void testDecayFactor() throws Exception {
+        StreamingKMeans streamingKMeans =
+                new StreamingKMeans()
+                        .setInitMode("random")
+                        .setDims(2)
+                        .setInitWeights(new Double[] {0., 0.})
+                        .setDecayFactor(100.0)
+                        .setBatchSize(6)
+                        .setFeaturesCol("features")
+                        .setPredictionCol("prediction");
+        StreamingKMeansModel streamingModel = streamingKMeans.fit(trainTable);
+        configModelSink(streamingModel);
+
+        clients.add(env.executeAsync());
+        waitInitModelDataSetup();
+
+        MockMessageQueues.offerAll(trainId, trainData1);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups1,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+
+        MockMessageQueues.offerAll(trainId, trainData2);
+        waitModelDataUpdate();
+        predictAndAssert(
+                expectedGroups1,
+                streamingKMeans.getFeaturesCol(),
+                streamingKMeans.getPredictionCol());
+    }
+
+    @Test
+    public void testSaveAndReload() throws Exception {
+        KMeans offlineKMeans =
+                new 
KMeans().setFeaturesCol("features").setPredictionCol("prediction");
+        KMeansModel offlineModel = offlineKMeans.fit(offlineTrainTable);
+
+        StreamingKMeans streamingKMeans =
+                new StreamingKMeans(offlineModel.getModelData())
+                        .setDims(2)
+                        .setBatchSize(6)
+                        .setInitWeights(new Double[] {0., 0.});
+        ReadWriteUtils.updateExistingParams(streamingKMeans, 
offlineKMeans.getParamMap());
+
+        StreamingKMeans loadedKMeans =
+                StageTestUtils.saveAndReload(
+                        env, streamingKMeans, 
tempFolder.newFolder().getAbsolutePath());
+
+        StreamingKMeansModel streamingModel = loadedKMeans.fit(trainTable);
+
+        String modelDataPassId = MockMessageQueues.createMessageQueue();
+        queueIds.add(modelDataPassId);
+
+        String modelSavePath = tempFolder.newFolder().getAbsolutePath();
+        streamingModel.save(modelSavePath);
+        KMeansModelData.getModelDataStream(streamingModel.getModelData()[0])
+                .addSink(new MockSinkFunction<>(modelDataPassId));
+        clients.add(env.executeAsync());

Review comment:
       In the current `save()` methods of `Pipeline`/`Graph`, I think we have 
implicitly assumed that calling `execute()` after `save()` will always 
unblock the process (i.e. return), which is not true if online algorithms are 
involved. If we do not rely on that assumption, then using `executeAsync()` 
would not be an issue.

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/MockKVStore.java
##########
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/** Class that manages global key-value pairs used in unit tests. */
+@SuppressWarnings({"unchecked"})
+public class MockKVStore {

Review comment:
       `TestMetricReporter` is not a singleton, as you can see from the 
following in the test:
   ```java
   config.setString(
           "metrics.reporter.test_reporter.class", 
TestMetricReporter.class.getName());
   ```
   Flink will use this information to instantiate `TestMetricReporter` on its 
own, so it will not be a singleton.
   
   But I agree that we can remove `MockKVStore` for now, as 
`TestMetricReporter` is the only class using it. `TestMetricReporter` can 
store values in its own static variables and provide a static get method so 
that test cases can read them from the Flink client side.

##########
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/util/MockMessageQueues.java
##########
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+
+/** A class that manages global message queues used in unit tests. */
+@SuppressWarnings({"unchecked", "rawtypes"})
+public class MockMessageQueues {
+    private static final Map<String, BlockingQueue> queueMap = new HashMap<>();
+    private static long counter = 0;
+
+    public static synchronized String createMessageQueue() {
+        String id = String.valueOf(counter);
+        queueMap.put(id, new LinkedBlockingQueue<>());
+        counter++;
+        return id;
+    }
+
+    @SafeVarargs
+    public static <T> void offerAll(String id, T... values) throws 
InterruptedException {
+        for (T value : values) {
+            offer(id, value);
+        }
+    }
+
+    public static <T> void offer(String id, T value) throws 
InterruptedException {
+        offer(id, value, 1, TimeUnit.MINUTES);
+    }
+
+    public static <T> void offer(String id, T value, long timeout, TimeUnit 
unit)
+            throws InterruptedException {
+        boolean success = queueMap.get(id).offer(value, timeout, unit);
+        if (!success) {
+            throw new RuntimeException(
+                    "Failed to offer " + value + " to blocking queue " + id + 
".");
+        }
+    }
+
+    public static <T> List<T> poll(String id, int num) throws 
InterruptedException {
+        List<T> result = new ArrayList<>();
+        for (int i = 0; i < num; i++) {
+            result.add(poll(id));
+        }
+        return result;
+    }
+
+    public static <T> T poll(String id) throws InterruptedException {
+        return poll(id, 1, TimeUnit.MINUTES);
+    }
+
+    public static <T> T poll(String id, long timeout, TimeUnit unit) throws 
InterruptedException {
+        T value = (T) queueMap.get(id).poll(timeout, unit);
+        if (value == null) {
+            throw new RuntimeException("Failed to poll next value from 
blocking queue " + id + ".");
+        }
+        return value;
+    }
+
+    public static void deleteBlockingQueue(String id) {

Review comment:
       In cases where multiple test classes are running in parallel, I'm 
afraid that calling a `clear()` method in one test case would affect the 
others. Thus I would prefer to have each test class specify and delete its 
own queues.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import 
org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * StreamingKMeans extends the function of {@link KMeans}, supporting to train 
a K-Means model
+ * continuously according to an unbounded stream of train data.
+ */
+public class StreamingKMeans
+        implements Estimator<StreamingKMeans, StreamingKMeansModel>,
+                StreamingKMeansParams<StreamingKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public StreamingKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public StreamingKMeans(Table... initModelDataTables) {
+        Preconditions.checkArgument(initModelDataTables.length == 1);
+        this.initModelDataTable = initModelDataTables[0];
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public StreamingKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        
Preconditions.checkArgument(HasBatchStrategy.COUNT_STRATEGY.equals(getBatchStrategy()));
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) 
tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new 
FeaturesExtractor(getFeaturesCol()));
+        points.getTransformation().setParallelism(1);

Review comment:
       I think it is hard to achieve for now. We need to create mini batches of 
fixed batch size from train data, but if the parallelism is larger than 1, we 
do not have a mechanism to count the total number of records received by each 
subtask.
   
   One possible solution I have wanted to propose is to insert barriers into 
the train data stream, so that even if the train data is distributed across 
different subtasks, the subtasks still know when to finish the current batch, 
as long as they receive the barrier. We have not got a chance to discuss this 
problem and possible solutions offline.
   
   For now I prefer to keep this limitation in this PR. I'll add a note about 
it in the Javadoc.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import 
org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * StreamingKMeans extends the function of {@link KMeans}, supporting to train 
a K-Means model
+ * continuously according to an unbounded stream of train data.
+ */
+public class StreamingKMeans
+        implements Estimator<StreamingKMeans, StreamingKMeansModel>,
+                StreamingKMeansParams<StreamingKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public StreamingKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public StreamingKMeans(Table... initModelDataTables) {
+        Preconditions.checkArgument(initModelDataTables.length == 1);
+        this.initModelDataTable = initModelDataTables[0];
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public StreamingKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        
Preconditions.checkArgument(HasBatchStrategy.COUNT_STRATEGY.equals(getBatchStrategy()));
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) 
tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new 
FeaturesExtractor(getFeaturesCol()));
+        points.getTransformation().setParallelism(1);
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            initModelDataStream = createRandomCentroids(env, getDims(), 
getK(), getSeed());
+        } else {
+            initModelDataStream = 
KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> 
initModelDataWithWeightsStream =
+                initModelDataStream.map(new 
InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new StreamingKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getBatchSize(),
+                        getK());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                
DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+        finalModelDataStream = finalModelDataStream.union(initModelDataStream);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        StreamingKMeansModel model = new 
StreamingKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, 
DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData 
modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static StreamingKMeans load(StreamExecutionEnvironment env, String 
path)
+            throws IOException {
+        StreamingKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new 
KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = 
tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class StreamingKMeansIterationBody implements IterationBody 
{
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+
+        public StreamingKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int 
batchSize, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> 
modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> 
newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "UpdateModelData",
+                                    new TupleTypeInfo<>(
+                                            
TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new 
UpdateModelDataOperator(distanceMeasure, decayFactor, k))
+                            .setParallelism(1);
+
+            DataStream<KMeansModelData> newModelData =
+                    newModelDataWithWeights.map(
+                            (MapFunction<Tuple2<KMeansModelData, DenseVector>, 
KMeansModelData>)
+                                    x -> x.f0);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelDataWithWeights), 
DataStreamList.of(newModelData));
+        }
+    }
+
+    // TODO: change this single-threaded implementation to support training in 
a distributed way,
+    // after model data
+    // version mechanism is implemented.
+    private static class UpdateModelDataOperator
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, 
DenseVector>>
+            implements TwoInputStreamOperator<
+                    DenseVector[],
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int k;
+        private ListState<DenseVector[]> miniBatchState;
+        private ListState<KMeansModelData> modelDataState;
+        private ListState<DenseVector> weightsState;
+
+        public UpdateModelDataOperator(DistanceMeasure distanceMeasure, double 
decayFactor, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.k = k;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws 
Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    
ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            miniBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new 
ListStateDescriptor<>("miniBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", 
KMeansModelData.class));
+
+            weightsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "weights", 
DenseVectorTypeInfo.INSTANCE));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> streamRecord) 
throws Exception {
+            miniBatchState.add(streamRecord.getValue());
+            processElement();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<Tuple2<KMeansModelData, 
DenseVector>> streamRecord)
+                throws Exception {
+            modelDataState.add(streamRecord.getValue().f0);
+            weightsState.add(streamRecord.getValue().f1);
+            processElement();
+        }
+
+        private void processElement() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !miniBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            // Retrieves data from states.
+            List<KMeansModelData> modelDataList =
+                    IteratorUtils.toList(modelDataState.get().iterator());
+            if (modelDataList.size() != 1) {
+                throw new RuntimeException(
+                        "The operator received "
+                                + modelDataList.size()
+                                + " list of model data in this round");
+            }
+            DenseVector[] centroids = modelDataList.get(0).centroids;
+            modelDataState.clear();

Review comment:
       The behavior of OnlineKMeans, as I have conceived it, is as follows. In each 
iteration, the algorithm should consume one batch of train data and one set of 
model data received from the feedback edge. If there are multiple batches of data 
waiting to be consumed, the operator would still consume one batch at a time. 
If there are zero batches of data when model data records are received, the 
operator would just cache the model data, waiting for train data to come in.
   
   So when model data comes in, it does not consume input points received 
before it. It just consumes the next batch of data, which might or might not 
have arrived yet.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import 
org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * StreamingKMeans extends the function of {@link KMeans}, supporting to train 
a K-Means model
+ * continuously according to an unbounded stream of train data.
+ */
+public class StreamingKMeans
+        implements Estimator<StreamingKMeans, StreamingKMeansModel>,
+                StreamingKMeansParams<StreamingKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public StreamingKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public StreamingKMeans(Table... initModelDataTables) {
+        Preconditions.checkArgument(initModelDataTables.length == 1);
+        this.initModelDataTable = initModelDataTables[0];
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public StreamingKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        
Preconditions.checkArgument(HasBatchStrategy.COUNT_STRATEGY.equals(getBatchStrategy()));
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) 
tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new 
FeaturesExtractor(getFeaturesCol()));
+        points.getTransformation().setParallelism(1);
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            initModelDataStream = createRandomCentroids(env, getDims(), 
getK(), getSeed());
+        } else {
+            initModelDataStream = 
KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> 
initModelDataWithWeightsStream =
+                initModelDataStream.map(new 
InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new StreamingKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getBatchSize(),
+                        getK());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                
DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+        finalModelDataStream = finalModelDataStream.union(initModelDataStream);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        StreamingKMeansModel model = new 
StreamingKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, 
DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData 
modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static StreamingKMeans load(StreamExecutionEnvironment env, String 
path)
+            throws IOException {
+        StreamingKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new 
KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = 
tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class StreamingKMeansIterationBody implements IterationBody 
{
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+
+        public StreamingKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int 
batchSize, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> 
modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> 
newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "UpdateModelData",
+                                    new TupleTypeInfo<>(
+                                            
TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new 
UpdateModelDataOperator(distanceMeasure, decayFactor, k))
+                            .setParallelism(1);
+
+            DataStream<KMeansModelData> newModelData =
+                    newModelDataWithWeights.map(
+                            (MapFunction<Tuple2<KMeansModelData, DenseVector>, 
KMeansModelData>)
+                                    x -> x.f0);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelDataWithWeights), 
DataStreamList.of(newModelData));
+        }
+    }
+
+    // TODO: change this single-threaded implementation to support training in 
a distributed way,
+    // after model data
+    // version mechanism is implemented.
+    private static class UpdateModelDataOperator
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, 
DenseVector>>
+            implements TwoInputStreamOperator<
+                    DenseVector[],
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int k;
+        private ListState<DenseVector[]> miniBatchState;
+        private ListState<KMeansModelData> modelDataState;
+        private ListState<DenseVector> weightsState;
+
+        public UpdateModelDataOperator(DistanceMeasure distanceMeasure, double 
decayFactor, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.k = k;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws 
Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    
ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            miniBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new 
ListStateDescriptor<>("miniBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", 
KMeansModelData.class));
+
+            weightsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "weights", 
DenseVectorTypeInfo.INSTANCE));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> streamRecord) 
throws Exception {
+            miniBatchState.add(streamRecord.getValue());
+            processElement();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<Tuple2<KMeansModelData, 
DenseVector>> streamRecord)
+                throws Exception {
+            modelDataState.add(streamRecord.getValue().f0);
+            weightsState.add(streamRecord.getValue().f1);
+            processElement();
+        }
+
+        private void processElement() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !miniBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            // Retrieves data from states.
+            List<KMeansModelData> modelDataList =
+                    IteratorUtils.toList(modelDataState.get().iterator());
+            if (modelDataList.size() != 1) {
+                throw new RuntimeException(
+                        "The operator received "
+                                + modelDataList.size()
+                                + " list of model data in this round");
+            }
+            DenseVector[] centroids = modelDataList.get(0).centroids;
+            modelDataState.clear();
+
+            List<DenseVector> weightsList = 
IteratorUtils.toList(weightsState.get().iterator());
+            if (weightsList.size() != 1) {
+                throw new RuntimeException(
+                        "The operator received "
+                                + weightsList.size()
+                                + " list of weights in this round");
+            }
+            DenseVector weights = weightsList.get(0);
+            weightsState.clear();
+
+            List<DenseVector[]> pointsList = 
IteratorUtils.toList(miniBatchState.get().iterator());
+            DenseVector[] points = pointsList.get(0);
+            pointsList.remove(0);
+            miniBatchState.clear();
+            miniBatchState.addAll(pointsList);

Review comment:
       As described above, in an online algorithm I suppose the training process 
should produce one model data update for each batch of train data. If there are 
multiple batches of train data, they should not be merged into one. If there 
are zero batches of train data, the iteration body would not continue training 
on an empty batch either.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import 
org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * StreamingKMeans extends the function of {@link KMeans}, supporting to train 
a K-Means model

Review comment:
       I agree. I'll refer to Spark's documentation and add similar 
JavaDocs here.
   
   Besides, Spark's JavaDoc is mainly on `StreamingKMeansModel`, rather than 
`StreamingKMeans`. I think it should be `StreamingKMeans` that is explained 
in more detail, and similar to Spark's JavaDoc, my added documentation would be 
mainly about the training process, so I'll add the detailed documentation to 
`OnlineKMeans`.

##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/MockKVStore.java
##########
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/** Class that manages global key-value pairs used in unit tests. */
+@SuppressWarnings({"unchecked"})
+public class MockKVStore {
+    private static final Map<String, Object> map = new HashMap<>();
+
+    private static long counter = 0;
+
+    /**
+     * Returns a prefix string to make sure key-value pairs created in 
different test cases would
+     * not have the same key.
+     */
+    public static synchronized String createNonDuplicatePrefix() {

Review comment:
       I agree that it would be better to remove this method. But in this case 
a random int64 value might be risky, as test cases running in parallel might 
generate identical keys. I'll try to avoid this problem by using a more complex 
key pattern, such as combining the class name with a monotonically increasing int64.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasDecayFactor.java
##########
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared decay factor param. */
+public interface HasDecayFactor<T> extends WithParams<T> {
+    Param<Double> DECAY_FACTOR =
+            new DoubleParam(
+                    "decayFactor",
+                    "The forgetfulness of the previous centroids.",
+                    0.,
+                    ParamValidators.gtEq(0));

Review comment:
       Spark does not set an upper bound on this parameter. It only 
ensures that decayFactor is nonnegative, as follows.
   ```scala
     def setDecayFactor(a: Double): this.type = {
       require(a >= 0,
         s"Decay factor must be nonnegative but got ${a}")
       this.decayFactor = a
       this
     }
   ```
   I also understand your concern. If this value is larger than 1, then it 
cannot represent the so-called "forgetfulness", as the weight of the init model data is 
always strengthened... Maybe we can require it to be smaller than 1, and 
remove this limit when we find relevant use cases. What do you think?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import 
org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * StreamingKMeans extends the function of {@link KMeans}, supporting to train 
a K-Means model
+ * continuously according to an unbounded stream of train data.
+ */
+public class StreamingKMeans
+        implements Estimator<StreamingKMeans, StreamingKMeansModel>,
+                StreamingKMeansParams<StreamingKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public StreamingKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public StreamingKMeans(Table... initModelDataTables) {
+        Preconditions.checkArgument(initModelDataTables.length == 1);
+        this.initModelDataTable = initModelDataTables[0];
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public StreamingKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        
Preconditions.checkArgument(HasBatchStrategy.COUNT_STRATEGY.equals(getBatchStrategy()));
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) 
tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new 
FeaturesExtractor(getFeaturesCol()));
+        points.getTransformation().setParallelism(1);
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            initModelDataStream = createRandomCentroids(env, getDims(), 
getK(), getSeed());

Review comment:
       I agree. In fact, we only need to check `initModelDataTable == null` when 
initMode is `random`, because if initMode is `direct`, the implementation that follows 
would naturally fail unless `initModelDataTable != null`.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import 
org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * StreamingKMeans extends the function of {@link KMeans}, supporting to train 
a K-Means model
+ * continuously according to an unbounded stream of train data.
+ */
+public class StreamingKMeans
+        implements Estimator<StreamingKMeans, StreamingKMeansModel>,
+                StreamingKMeansParams<StreamingKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public StreamingKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public StreamingKMeans(Table... initModelDataTables) {
+        Preconditions.checkArgument(initModelDataTables.length == 1);
+        this.initModelDataTable = initModelDataTables[0];
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public StreamingKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        
Preconditions.checkArgument(HasBatchStrategy.COUNT_STRATEGY.equals(getBatchStrategy()));
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) 
tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new 
FeaturesExtractor(getFeaturesCol()));
+        points.getTransformation().setParallelism(1);
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            initModelDataStream = createRandomCentroids(env, getDims(), 
getK(), getSeed());
+        } else {
+            initModelDataStream = 
KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> 
initModelDataWithWeightsStream =
+                initModelDataStream.map(new 
InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new StreamingKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getBatchSize(),
+                        getK());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                
DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+        finalModelDataStream = finalModelDataStream.union(initModelDataStream);

Review comment:
       Not with the current implementation. I tried making the output of the 
iteration body the input feedback stream, and this modification can 
guarantee it.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeansModel.java
##########
@@ -0,0 +1,176 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.co.CoProcessFunction;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * StreamingKMeansModel can be regarded as an advanced {@link KMeansModel} 
operator which can update
+ * model data in a streaming format, using the model data provided by {@link 
StreamingKMeans}.
+ */
+public class StreamingKMeansModel
+        implements Model<StreamingKMeansModel>, 
KMeansModelParams<StreamingKMeansModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public StreamingKMeansModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public StreamingKMeansModel setModelData(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = 
TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), 
Types.INT),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), 
getPredictionCol()));
+
+        DataStream<Row> predictionResult =
+                KMeansModelData.getModelDataStream(modelDataTable)
+                        .broadcast()
+                        .connect(tEnv.toDataStream(inputs[0]))
+                        .process(
+                                new PredictLabelFunction(
+                                        getFeaturesCol(),
+                                        
DistanceMeasure.getInstance(getDistanceMeasure())),
+                                outputTypeInfo);
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility function used for prediction. */
+    private static class PredictLabelFunction extends 
CoProcessFunction<KMeansModelData, Row, Row> {
+        private final String featuresCol;
+
+        private final DistanceMeasure distanceMeasure;
+
+        private DenseVector[] centroids;
+
+        // TODO: replace this with a complete solution of reading first model 
data from unbounded
+        // model data stream before processing the first predict data.
+        private final List<Row> cache = new ArrayList<>();

Review comment:
       OK. I'll make the change.

##########
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
##########
@@ -207,25 +211,21 @@ public void testSaveLoadAndPredict() throws Exception {
         KMeansModel loadedModel =
                 StageTestUtils.saveAndReload(env, model, 
tempFolder.newFolder().getAbsolutePath());
         Table output = loadedModel.transform(dataTable)[0];
-        assertEquals(
-                Collections.singletonList("centroids"),
-                
loadedModel.getModelData()[0].getResolvedSchema().getColumnNames());
         assertEquals(
                 Arrays.asList("features", "prediction"),
                 output.getResolvedSchema().getColumnNames());
 
+        List<Row> results = IteratorUtils.toList(output.execute().collect());
         List<Set<DenseVector>> actualGroups =
-                executeAndCollect(output, kmeans.getFeaturesCol(), 
kmeans.getPredictionCol());
+                groupFeaturesByPrediction(
+                        results, kmeans.getFeaturesCol(), 
kmeans.getPredictionCol());
         assertTrue(CollectionUtils.isEqualCollection(expectedGroups, 
actualGroups));
     }
 
     @Test
     public void testGetModelData() throws Exception {
         KMeans kmeans = new KMeans().setMaxIter(2).setK(2);
         KMeansModel model = kmeans.fit(dataTable);
-        assertEquals(
-                Collections.singletonList("centroids"),

Review comment:
       I once changed the design of `KMeansModelData` to make it contain 
more than centroids (adding a weights field), and this removal is a result of 
that change. I have since restored `KMeansModelData`'s structure, but forgot to 
add this check back. I'll fix it.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/StreamingKMeans.java
##########
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import 
org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * StreamingKMeans extends the function of {@link KMeans}, supporting to train 
a K-Means model
+ * continuously according to an unbounded stream of train data.
+ */
+public class StreamingKMeans
+        implements Estimator<StreamingKMeans, StreamingKMeansModel>,
+                StreamingKMeansParams<StreamingKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public StreamingKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public StreamingKMeans(Table... initModelDataTables) {
+        Preconditions.checkArgument(initModelDataTables.length == 1);
+        this.initModelDataTable = initModelDataTables[0];
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public StreamingKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        
Preconditions.checkArgument(HasBatchStrategy.COUNT_STRATEGY.equals(getBatchStrategy()));
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) 
tEnv).execEnv();
+
+        DataStream<DenseVector> points =
+                tEnv.toDataStream(inputs[0]).map(new 
FeaturesExtractor(getFeaturesCol()));
+        points.getTransformation().setParallelism(1);
+
+        DataStream<KMeansModelData> initModelDataStream;
+        if (getInitMode().equals("random")) {
+            initModelDataStream = createRandomCentroids(env, getDims(), 
getK(), getSeed());
+        } else {
+            initModelDataStream = 
KMeansModelData.getModelDataStream(initModelDataTable);
+        }
+        DataStream<Tuple2<KMeansModelData, DenseVector>> 
initModelDataWithWeightsStream =
+                initModelDataStream.map(new 
InitWeightAssigner(getInitWeights()));
+        initModelDataWithWeightsStream.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new StreamingKMeansIterationBody(
+                        DistanceMeasure.getInstance(getDistanceMeasure()),
+                        getDecayFactor(),
+                        getBatchSize(),
+                        getK());
+
+        DataStream<KMeansModelData> finalModelDataStream =
+                Iterations.iterateUnboundedStreams(
+                                
DataStreamList.of(initModelDataWithWeightsStream),
+                                DataStreamList.of(points),
+                                body)
+                        .get(0);
+        finalModelDataStream = finalModelDataStream.union(initModelDataStream);
+
+        Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream);
+        StreamingKMeansModel model = new 
StreamingKMeansModel().setModelData(finalModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    private static class InitWeightAssigner
+            implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, 
DenseVector>> {
+        private final double[] initWeights;
+
+        private InitWeightAssigner(Double[] initWeights) {
+            this.initWeights = ArrayUtils.toPrimitive(initWeights);
+        }
+
+        @Override
+        public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData 
modelData)
+                throws Exception {
+            return Tuple2.of(modelData, Vectors.dense(initWeights));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        if (initModelDataTable != null) {
+            ReadWriteUtils.saveModelData(
+                    KMeansModelData.getModelDataStream(initModelDataTable),
+                    path,
+                    new KMeansModelData.ModelDataEncoder());
+        }
+
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static StreamingKMeans load(StreamExecutionEnvironment env, String 
path)
+            throws IOException {
+        StreamingKMeans kMeans = ReadWriteUtils.loadStageParam(path);
+
+        Path initModelDataPath = Paths.get(path, "data");
+        if (Files.exists(initModelDataPath)) {
+            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+            DataStream<KMeansModelData> initModelDataStream =
+                    ReadWriteUtils.loadModelData(env, path, new 
KMeansModelData.ModelDataDecoder());
+
+            kMeans.initModelDataTable = 
tEnv.fromDataStream(initModelDataStream);
+            kMeans.setInitMode("direct");
+        }
+
+        return kMeans;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class StreamingKMeansIterationBody implements IterationBody 
{
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int batchSize;
+        private final int k;
+
+        public StreamingKMeansIterationBody(
+                DistanceMeasure distanceMeasure, double decayFactor, int 
batchSize, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.batchSize = batchSize;
+            this.k = k;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<Tuple2<KMeansModelData, DenseVector>> 
modelDataWithWeights =
+                    variableStreams.get(0);
+            DataStream<DenseVector> points = dataStreams.get(0);
+
+            DataStream<Tuple2<KMeansModelData, DenseVector>> 
newModelDataWithWeights =
+                    points.countWindowAll(batchSize)
+                            .aggregate(new MiniBatchCreator())
+                            .connect(modelDataWithWeights.broadcast())
+                            .transform(
+                                    "UpdateModelData",
+                                    new TupleTypeInfo<>(
+                                            
TypeInformation.of(KMeansModelData.class),
+                                            DenseVectorTypeInfo.INSTANCE),
+                                    new 
UpdateModelDataOperator(distanceMeasure, decayFactor, k))
+                            .setParallelism(1);
+
+            DataStream<KMeansModelData> newModelData =
+                    newModelDataWithWeights.map(
+                            (MapFunction<Tuple2<KMeansModelData, DenseVector>, 
KMeansModelData>)
+                                    x -> x.f0);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelDataWithWeights), 
DataStreamList.of(newModelData));
+        }
+    }
+
+    // TODO: change this single-threaded implementation to support training
+    // in a distributed way after the model data version mechanism is
+    // implemented.
+    private static class UpdateModelDataOperator
+            extends AbstractStreamOperator<Tuple2<KMeansModelData, 
DenseVector>>
+            implements TwoInputStreamOperator<
+                    DenseVector[],
+                    Tuple2<KMeansModelData, DenseVector>,
+                    Tuple2<KMeansModelData, DenseVector>> {
+        private final DistanceMeasure distanceMeasure;
+        private final double decayFactor;
+        private final int k;
+        private ListState<DenseVector[]> miniBatchState;
+        private ListState<KMeansModelData> modelDataState;
+        private ListState<DenseVector> weightsState;
+
+        public UpdateModelDataOperator(DistanceMeasure distanceMeasure, double 
decayFactor, int k) {
+            this.distanceMeasure = distanceMeasure;
+            this.decayFactor = decayFactor;
+            this.k = k;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws 
Exception {
+            super.initializeState(context);
+
+            TypeInformation<DenseVector[]> type =
+                    
ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            miniBatchState =
+                    context.getOperatorStateStore()
+                            .getListState(new 
ListStateDescriptor<>("miniBatch", type));
+
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", 
KMeansModelData.class));
+
+            weightsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "weights", 
DenseVectorTypeInfo.INSTANCE));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<DenseVector[]> streamRecord) 
throws Exception {
+            miniBatchState.add(streamRecord.getValue());
+            processElement();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<Tuple2<KMeansModelData, 
DenseVector>> streamRecord)
+                throws Exception {
+            modelDataState.add(streamRecord.getValue().f0);
+            weightsState.add(streamRecord.getValue().f1);
+            processElement();
+        }
+
+        private void processElement() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !miniBatchState.get().iterator().hasNext()) {
+                return;
+            }
+
+            // Retrieves data from states.
+            List<KMeansModelData> modelDataList =
+                    IteratorUtils.toList(modelDataState.get().iterator());
+            if (modelDataList.size() != 1) {

Review comment:
       Got it. I'll also apply it to `KMeans`.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to