lindong28 commented on a change in pull request #70: URL: https://github.com/apache/flink-ml/pull/70#discussion_r832171267
########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java ########## @@ -0,0 +1,487 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.clustering; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.RestOptions; +import org.apache.flink.metrics.Gauge; +import org.apache.flink.ml.clustering.kmeans.KMeans; +import org.apache.flink.ml.clustering.kmeans.KMeansModel; +import org.apache.flink.ml.clustering.kmeans.KMeansModelData; +import org.apache.flink.ml.clustering.kmeans.OnlineKMeans; +import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel; +import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.util.InMemorySinkFunction; +import org.apache.flink.ml.util.InMemorySourceFunction; +import org.apache.flink.runtime.minicluster.MiniCluster; +import 
org.apache.flink.runtime.minicluster.MiniClusterConfiguration; +import org.apache.flink.runtime.testutils.InMemoryReporter; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.test.util.AbstractTestBase; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.CollectionUtils; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction; + +/** Tests {@link OnlineKMeans} and {@link OnlineKMeansModel}. 
*/ +public class OnlineKMeansTest extends AbstractTestBase { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + + private static final DenseVector[] trainData1 = + new DenseVector[] { + Vectors.dense(10.0, 0.0), + Vectors.dense(10.0, 0.3), + Vectors.dense(10.3, 0.0), + Vectors.dense(-10.0, 0.0), + Vectors.dense(-10.0, 0.6), + Vectors.dense(-10.6, 0.0) + }; + private static final DenseVector[] trainData2 = + new DenseVector[] { + Vectors.dense(10.0, 100.0), + Vectors.dense(10.0, 100.3), + Vectors.dense(10.3, 100.0), + Vectors.dense(-10.0, -100.0), + Vectors.dense(-10.0, -100.6), + Vectors.dense(-10.6, -100.0) + }; + private static final DenseVector[] predictData = + new DenseVector[] { + Vectors.dense(10.0, 10.0), + Vectors.dense(10.3, 10.0), + Vectors.dense(10.0, 10.3), + Vectors.dense(-10.0, 10.0), + Vectors.dense(-10.3, 10.0), + Vectors.dense(-10.0, 10.3) + }; + private static final List<Set<DenseVector>> expectedGroups1 = + Arrays.asList( + new HashSet<>( + Arrays.asList( + Vectors.dense(10.0, 10.0), + Vectors.dense(10.3, 10.0), + Vectors.dense(10.0, 10.3))), + new HashSet<>( + Arrays.asList( + Vectors.dense(-10.0, 10.0), + Vectors.dense(-10.3, 10.0), + Vectors.dense(-10.0, 10.3)))); + private static final List<Set<DenseVector>> expectedGroups2 = + Collections.singletonList( + new HashSet<>( + Arrays.asList( + Vectors.dense(10.0, 10.0), + Vectors.dense(10.3, 10.0), + Vectors.dense(10.0, 10.3), + Vectors.dense(-10.0, 10.0), + Vectors.dense(-10.3, 10.0), + Vectors.dense(-10.0, 10.3)))); + + private static final int defaultParallelism = 4; + private static final int numTaskManagers = 2; + private static final int numSlotsPerTaskManager = 2; + + private String currentModelDataVersion; + + private InMemorySourceFunction<DenseVector> trainSource; + private InMemorySourceFunction<DenseVector> predictSource; + private InMemorySinkFunction<Row> outputSink; + private InMemorySinkFunction<KMeansModelData> modelDataSink; + + private InMemoryReporter 
reporter; + private MiniCluster miniCluster; + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + + private Table offlineTrainTable; + private Table trainTable; + private Table predictTable; + + @Before + public void before() throws Exception { + currentModelDataVersion = "0"; + + trainSource = new InMemorySourceFunction<>(); + predictSource = new InMemorySourceFunction<>(); + outputSink = new InMemorySinkFunction<>(); + modelDataSink = new InMemorySinkFunction<>(); + + Configuration config = new Configuration(); + config.set(RestOptions.BIND_PORT, "18081-19091"); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + reporter = InMemoryReporter.createWithRetainedMetrics(); + reporter.addToConfiguration(config); + + miniCluster = + new MiniCluster( + new MiniClusterConfiguration.Builder() + .setConfiguration(config) + .setNumTaskManagers(numTaskManagers) + .setNumSlotsPerTaskManager(numSlotsPerTaskManager) + .build()); + miniCluster.start(); + + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(defaultParallelism); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + + Schema schema = Schema.newBuilder().column("f0", DataTypes.of(DenseVector.class)).build(); + + offlineTrainTable = + tEnv.fromDataStream(env.fromElements(trainData1), schema).as("features"); + trainTable = + tEnv.fromDataStream( + env.addSource(trainSource, DenseVectorTypeInfo.INSTANCE), schema) + .as("features"); + predictTable = + tEnv.fromDataStream( + env.addSource(predictSource, DenseVectorTypeInfo.INSTANCE), schema) + .as("features"); + } + + @After + public void after() throws Exception { + miniCluster.close(); + } + + /** + * Performs transform() on the provided model with predictTable, and adds sinks for + * OnlineKMeansModel's transform output and model data. 
+ */ + private void configTransformAndSink(OnlineKMeansModel onlineModel) { Review comment: nits: it is not clear what `config` in the method name refers to. How about renaming it as `transformAndOutputData(...)`? ########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/InMemorySourceFunction.java ########## @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.util; + +import org.apache.flink.configuration.Configuration; +import org.apache.flink.streaming.api.functions.source.RichSourceFunction; +import org.apache.flink.streaming.api.functions.source.SourceFunction; + +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +/** A {@link SourceFunction} implementation that can directly receive records from tests. 
*/ +@SuppressWarnings({"unchecked", "rawtypes"}) +public class InMemorySourceFunction<T> extends RichSourceFunction<T> { + private static final Map<UUID, BlockingQueue> queueMap = new ConcurrentHashMap<>(); + private final UUID id; + private BlockingQueue<T> queue; + private boolean isRunning = true; + + public InMemorySourceFunction() { + id = UUID.randomUUID(); + queue = new LinkedBlockingQueue(); + queueMap.put(id, queue); + } + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + queue = queueMap.get(id); + } + + @Override + public void close() throws Exception { + super.close(); + queueMap.remove(id); + } + + @Override + public void run(SourceContext<T> context) { + while (isRunning) { + T value = queue.poll(); + if (value == null) { + Thread.yield(); Review comment: Would it be more performant and simpler to use `queue.take()` to get values from the queue? This is a blocking call, which reduces the chance of a busy loop. And if we agree to do this, we need to make sure the `InMemorySourceFunction` will not block on this queue forever when the test finishes. One approach is to add a dummy value to the queue when `cancel()` is invoked. ########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java ########## @@ -0,0 +1,487 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.clustering; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.RestOptions; +import org.apache.flink.metrics.Gauge; +import org.apache.flink.ml.clustering.kmeans.KMeans; +import org.apache.flink.ml.clustering.kmeans.KMeansModel; +import org.apache.flink.ml.clustering.kmeans.KMeansModelData; +import org.apache.flink.ml.clustering.kmeans.OnlineKMeans; +import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel; +import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.util.InMemorySinkFunction; +import org.apache.flink.ml.util.InMemorySourceFunction; +import org.apache.flink.runtime.minicluster.MiniCluster; +import org.apache.flink.runtime.minicluster.MiniClusterConfiguration; +import org.apache.flink.runtime.testutils.InMemoryReporter; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import 
org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.test.util.AbstractTestBase; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.CollectionUtils; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction; + +/** Tests {@link OnlineKMeans} and {@link OnlineKMeansModel}. */ +public class OnlineKMeansTest extends AbstractTestBase { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + + private static final DenseVector[] trainData1 = + new DenseVector[] { + Vectors.dense(10.0, 0.0), + Vectors.dense(10.0, 0.3), + Vectors.dense(10.3, 0.0), + Vectors.dense(-10.0, 0.0), + Vectors.dense(-10.0, 0.6), + Vectors.dense(-10.6, 0.0) + }; + private static final DenseVector[] trainData2 = + new DenseVector[] { + Vectors.dense(10.0, 100.0), + Vectors.dense(10.0, 100.3), + Vectors.dense(10.3, 100.0), + Vectors.dense(-10.0, -100.0), + Vectors.dense(-10.0, -100.6), + Vectors.dense(-10.6, -100.0) + }; + private static final DenseVector[] predictData = + new DenseVector[] { + Vectors.dense(10.0, 10.0), + Vectors.dense(10.3, 10.0), + Vectors.dense(10.0, 10.3), + Vectors.dense(-10.0, 10.0), + Vectors.dense(-10.3, 10.0), + Vectors.dense(-10.0, 10.3) + }; + private static final List<Set<DenseVector>> expectedGroups1 = + Arrays.asList( + new HashSet<>( + Arrays.asList( + Vectors.dense(10.0, 10.0), + Vectors.dense(10.3, 10.0), + Vectors.dense(10.0, 10.3))), + new HashSet<>( + Arrays.asList( + Vectors.dense(-10.0, 10.0), + Vectors.dense(-10.3, 10.0), + Vectors.dense(-10.0, 10.3)))); + private static final List<Set<DenseVector>> expectedGroups2 = + Collections.singletonList( + new HashSet<>( 
+ Arrays.asList( + Vectors.dense(10.0, 10.0), + Vectors.dense(10.3, 10.0), + Vectors.dense(10.0, 10.3), + Vectors.dense(-10.0, 10.0), + Vectors.dense(-10.3, 10.0), + Vectors.dense(-10.0, 10.3)))); + + private static final int defaultParallelism = 4; + private static final int numTaskManagers = 2; + private static final int numSlotsPerTaskManager = 2; + + private String currentModelDataVersion; + + private InMemorySourceFunction<DenseVector> trainSource; + private InMemorySourceFunction<DenseVector> predictSource; + private InMemorySinkFunction<Row> outputSink; + private InMemorySinkFunction<KMeansModelData> modelDataSink; + + private InMemoryReporter reporter; + private MiniCluster miniCluster; + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + + private Table offlineTrainTable; + private Table trainTable; + private Table predictTable; + + @Before + public void before() throws Exception { + currentModelDataVersion = "0"; + + trainSource = new InMemorySourceFunction<>(); + predictSource = new InMemorySourceFunction<>(); + outputSink = new InMemorySinkFunction<>(); + modelDataSink = new InMemorySinkFunction<>(); + + Configuration config = new Configuration(); + config.set(RestOptions.BIND_PORT, "18081-19091"); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + reporter = InMemoryReporter.createWithRetainedMetrics(); + reporter.addToConfiguration(config); + + miniCluster = + new MiniCluster( + new MiniClusterConfiguration.Builder() + .setConfiguration(config) + .setNumTaskManagers(numTaskManagers) + .setNumSlotsPerTaskManager(numSlotsPerTaskManager) + .build()); + miniCluster.start(); + + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(defaultParallelism); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + + Schema schema = Schema.newBuilder().column("f0", 
DataTypes.of(DenseVector.class)).build(); + + offlineTrainTable = + tEnv.fromDataStream(env.fromElements(trainData1), schema).as("features"); + trainTable = + tEnv.fromDataStream( + env.addSource(trainSource, DenseVectorTypeInfo.INSTANCE), schema) + .as("features"); + predictTable = + tEnv.fromDataStream( + env.addSource(predictSource, DenseVectorTypeInfo.INSTANCE), schema) + .as("features"); + } + + @After + public void after() throws Exception { + miniCluster.close(); + } + + /** + * Performs transform() on the provided model with predictTable, and adds sinks for + * OnlineKMeansModel's transform output and model data. + */ + private void configTransformAndSink(OnlineKMeansModel onlineModel) { + Table outputTable = onlineModel.transform(predictTable)[0]; Review comment: nits: Would it be a bit simpler to do the following: ``` Table outputTable = onlineModel.transform(predictTable)[0]; tEnv.toDataStream(outputTable).addSink(outputSink); Table modelDataTable = onlineModel.getModelData()[0]; KMeansModelData.getModelDataStream(modelDataTable).addSink(modelDataSink); ``` ########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java ########## @@ -0,0 +1,487 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.clustering; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.RestOptions; +import org.apache.flink.metrics.Gauge; +import org.apache.flink.ml.clustering.kmeans.KMeans; +import org.apache.flink.ml.clustering.kmeans.KMeansModel; +import org.apache.flink.ml.clustering.kmeans.KMeansModelData; +import org.apache.flink.ml.clustering.kmeans.OnlineKMeans; +import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel; +import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.util.InMemorySinkFunction; +import org.apache.flink.ml.util.InMemorySourceFunction; +import org.apache.flink.runtime.minicluster.MiniCluster; +import org.apache.flink.runtime.minicluster.MiniClusterConfiguration; +import org.apache.flink.runtime.testutils.InMemoryReporter; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.test.util.AbstractTestBase; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.CollectionUtils; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import 
org.junit.rules.TemporaryFolder; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction; + +/** Tests {@link OnlineKMeans} and {@link OnlineKMeansModel}. */ +public class OnlineKMeansTest extends AbstractTestBase { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + + private static final DenseVector[] trainData1 = + new DenseVector[] { + Vectors.dense(10.0, 0.0), + Vectors.dense(10.0, 0.3), + Vectors.dense(10.3, 0.0), + Vectors.dense(-10.0, 0.0), + Vectors.dense(-10.0, 0.6), + Vectors.dense(-10.6, 0.0) + }; + private static final DenseVector[] trainData2 = + new DenseVector[] { + Vectors.dense(10.0, 100.0), + Vectors.dense(10.0, 100.3), + Vectors.dense(10.3, 100.0), + Vectors.dense(-10.0, -100.0), + Vectors.dense(-10.0, -100.6), + Vectors.dense(-10.6, -100.0) + }; + private static final DenseVector[] predictData = + new DenseVector[] { + Vectors.dense(10.0, 10.0), + Vectors.dense(10.3, 10.0), + Vectors.dense(10.0, 10.3), + Vectors.dense(-10.0, 10.0), + Vectors.dense(-10.3, 10.0), + Vectors.dense(-10.0, 10.3) + }; + private static final List<Set<DenseVector>> expectedGroups1 = + Arrays.asList( + new HashSet<>( + Arrays.asList( + Vectors.dense(10.0, 10.0), + Vectors.dense(10.3, 10.0), + Vectors.dense(10.0, 10.3))), + new HashSet<>( + Arrays.asList( + Vectors.dense(-10.0, 10.0), + Vectors.dense(-10.3, 10.0), + Vectors.dense(-10.0, 10.3)))); + private static final List<Set<DenseVector>> expectedGroups2 = + Collections.singletonList( + new HashSet<>( + Arrays.asList( + Vectors.dense(10.0, 10.0), + Vectors.dense(10.3, 10.0), + Vectors.dense(10.0, 10.3), + Vectors.dense(-10.0, 10.0), + Vectors.dense(-10.3, 10.0), + Vectors.dense(-10.0, 10.3)))); + + private static final int defaultParallelism = 4; + private static final int numTaskManagers = 2; + private static final int 
numSlotsPerTaskManager = 2; + + private String currentModelDataVersion; + + private InMemorySourceFunction<DenseVector> trainSource; + private InMemorySourceFunction<DenseVector> predictSource; + private InMemorySinkFunction<Row> outputSink; + private InMemorySinkFunction<KMeansModelData> modelDataSink; + + private InMemoryReporter reporter; + private MiniCluster miniCluster; + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + + private Table offlineTrainTable; + private Table trainTable; + private Table predictTable; + + @Before + public void before() throws Exception { + currentModelDataVersion = "0"; + + trainSource = new InMemorySourceFunction<>(); + predictSource = new InMemorySourceFunction<>(); + outputSink = new InMemorySinkFunction<>(); + modelDataSink = new InMemorySinkFunction<>(); + + Configuration config = new Configuration(); + config.set(RestOptions.BIND_PORT, "18081-19091"); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + reporter = InMemoryReporter.createWithRetainedMetrics(); + reporter.addToConfiguration(config); + + miniCluster = + new MiniCluster( + new MiniClusterConfiguration.Builder() + .setConfiguration(config) + .setNumTaskManagers(numTaskManagers) + .setNumSlotsPerTaskManager(numSlotsPerTaskManager) + .build()); + miniCluster.start(); + + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(defaultParallelism); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + + Schema schema = Schema.newBuilder().column("f0", DataTypes.of(DenseVector.class)).build(); + + offlineTrainTable = + tEnv.fromDataStream(env.fromElements(trainData1), schema).as("features"); + trainTable = + tEnv.fromDataStream( + env.addSource(trainSource, DenseVectorTypeInfo.INSTANCE), schema) + .as("features"); + predictTable = + tEnv.fromDataStream( + env.addSource(predictSource, 
DenseVectorTypeInfo.INSTANCE), schema) + .as("features"); + } + + @After + public void after() throws Exception { + miniCluster.close(); + } + + /** + * Performs transform() on the provided model with predictTable, and adds sinks for + * OnlineKMeansModel's transform output and model data. + */ + private void configTransformAndSink(OnlineKMeansModel onlineModel) { + Table outputTable = onlineModel.transform(predictTable)[0]; + DataStream<Row> output = tEnv.toDataStream(outputTable); + output.addSink(outputSink); + + Table modelDataTable = onlineModel.getModelData()[0]; + DataStream<KMeansModelData> modelDataStream = + KMeansModelData.getModelDataStream(modelDataTable); + modelDataStream.addSink(modelDataSink); + } + + /** Blocks the thread until Model has set up init model data. */ + private void waitInitModelDataSetup() { + while (reporter.findMetrics("modelDataVersion").size() < defaultParallelism) { + Thread.yield(); + } + waitModelDataUpdate(); + } + + /** Blocks the thread until the Model has received the next model-data-update event. */ + @SuppressWarnings("unchecked") + private void waitModelDataUpdate() { + do { + String tmpModelDataVersion = + String.valueOf( + reporter.findMetrics("modelDataVersion").values().stream() + .map(x -> Integer.parseInt(((Gauge<String>) x).getValue())) + .min(Integer::compareTo) + .orElse(Integer.parseInt(currentModelDataVersion))); + if (tmpModelDataVersion.equals(currentModelDataVersion)) { + Thread.yield(); + } else { + currentModelDataVersion = tmpModelDataVersion; + break; + } + } while (true); + } + + /** + * Inserts default predict data to the predict queue, fetches the prediction results, and + * asserts that the grouping result is as expected. 
+ * + * @param expectedGroups A list containing sets of features, which is the expected group result + * @param featuresCol Name of the column in the table that contains the features + * @param predictionCol Name of the column in the table that contains the prediction result + */ + private void predictAndAssert( + List<Set<DenseVector>> expectedGroups, String featuresCol, String predictionCol) + throws Exception { + predictSource.offerAll(OnlineKMeansTest.predictData); + List<Row> rawResult = outputSink.poll(OnlineKMeansTest.predictData.length); + List<Set<DenseVector>> actualGroups = + groupFeaturesByPrediction(rawResult, featuresCol, predictionCol); + Assert.assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups)); + } + + @Test + public void testParam() { + OnlineKMeans onlineKMeans = new OnlineKMeans(); + Assert.assertEquals("features", onlineKMeans.getFeaturesCol()); + Assert.assertEquals("prediction", onlineKMeans.getPredictionCol()); + Assert.assertEquals(EuclideanDistanceMeasure.NAME, onlineKMeans.getDistanceMeasure()); + Assert.assertEquals("random", onlineKMeans.getInitMode()); + Assert.assertEquals(2, onlineKMeans.getK()); + Assert.assertEquals(1, onlineKMeans.getDims()); + Assert.assertEquals("count", onlineKMeans.getBatchStrategy()); + Assert.assertEquals(1, onlineKMeans.getBatchSize()); + Assert.assertEquals(0., onlineKMeans.getDecayFactor(), 1e-5); + Assert.assertEquals("random", onlineKMeans.getInitMode()); + Assert.assertEquals(OnlineKMeans.class.getName().hashCode(), onlineKMeans.getSeed()); + + onlineKMeans + .setK(9) + .setFeaturesCol("test_feature") + .setPredictionCol("test_prediction") + .setK(3) + .setDims(5) + .setBatchSize(5) + .setDecayFactor(0.25) + .setInitMode("direct") + .setSeed(100); + + Assert.assertEquals("test_feature", onlineKMeans.getFeaturesCol()); + Assert.assertEquals("test_prediction", onlineKMeans.getPredictionCol()); + Assert.assertEquals(3, onlineKMeans.getK()); + Assert.assertEquals(5, 
onlineKMeans.getDims()); + Assert.assertEquals("count", onlineKMeans.getBatchStrategy()); + Assert.assertEquals(5, onlineKMeans.getBatchSize()); + Assert.assertEquals(0.25, onlineKMeans.getDecayFactor(), 1e-5); + Assert.assertEquals("direct", onlineKMeans.getInitMode()); + Assert.assertEquals(100, onlineKMeans.getSeed()); + } + + @Test + public void testFitAndPredict() throws Exception { + OnlineKMeans onlineKMeans = + new OnlineKMeans() + .setInitMode("random") + .setDims(2) + .setInitWeights(new Double[] {0., 0.}) + .setBatchSize(6) + .setFeaturesCol("features") + .setPredictionCol("prediction"); + OnlineKMeansModel onlineModel = onlineKMeans.fit(trainTable); + configTransformAndSink(onlineModel); + + miniCluster.submitJob(env.getStreamGraph().getJobGraph()); + waitInitModelDataSetup(); + + trainSource.offerAll(trainData1); + waitModelDataUpdate(); + predictAndAssert( + expectedGroups1, onlineKMeans.getFeaturesCol(), onlineKMeans.getPredictionCol()); + + trainSource.offerAll(trainData2); + waitModelDataUpdate(); + predictAndAssert( + expectedGroups2, onlineKMeans.getFeaturesCol(), onlineKMeans.getPredictionCol()); + } + + @Test + public void testInitWithKMeans() throws Exception { + KMeans kMeans = new KMeans().setFeaturesCol("features").setPredictionCol("prediction"); + KMeansModel kMeansModel = kMeans.fit(offlineTrainTable); + + OnlineKMeans onlineKMeans = + new OnlineKMeans(kMeansModel.getModelData()) + .setFeaturesCol("features") + .setPredictionCol("prediction") + .setInitMode("direct") + .setDims(2) + .setInitWeights(new Double[] {0., 0.}) + .setBatchSize(6); + + OnlineKMeansModel onlineModel = onlineKMeans.fit(trainTable); + configTransformAndSink(onlineModel); + + miniCluster.submitJob(env.getStreamGraph().getJobGraph()); + waitInitModelDataSetup(); + predictAndAssert( + expectedGroups1, onlineKMeans.getFeaturesCol(), onlineKMeans.getPredictionCol()); + + trainSource.offerAll(trainData2); + waitModelDataUpdate(); + predictAndAssert( + expectedGroups2, 
onlineKMeans.getFeaturesCol(), onlineKMeans.getPredictionCol()); + } + + @Test + public void testDecayFactor() throws Exception { + KMeans kMeans = new KMeans().setFeaturesCol("features").setPredictionCol("prediction"); + KMeansModel kMeansModel = kMeans.fit(offlineTrainTable); + + OnlineKMeans onlineKMeans = + new OnlineKMeans(kMeansModel.getModelData()) + .setDims(2) + .setInitWeights(new Double[] {3., 3.}) + .setDecayFactor(0.5) + .setBatchSize(6) + .setFeaturesCol("features") + .setPredictionCol("prediction"); + OnlineKMeansModel onlineModel = onlineKMeans.fit(trainTable); + configTransformAndSink(onlineModel); + + miniCluster.submitJob(env.getStreamGraph().getJobGraph()); + modelDataSink.poll(); + + trainSource.offerAll(trainData2); + KMeansModelData actualModelData = modelDataSink.poll(); + + KMeansModelData expectedModelData = + new KMeansModelData( + new DenseVector[] { + Vectors.dense(10.1, 200.3 / 3), Vectors.dense(-10.2, -200.2 / 3) + }); + + Assert.assertEquals(expectedModelData.centroids.length, actualModelData.centroids.length); + Arrays.sort(actualModelData.centroids, (o1, o2) -> (int) (o2.values[0] - o1.values[0])); Review comment: Would it be simpler to use `(o1, o2) -> Doubles.compare(o1.values[0], o2.values[0])`? (Note: the current comparator sorts in descending order, which matches the order of `expectedModelData` (10.1 before -10.2). To keep that order, the arguments need to be swapped, e.g. `(o1, o2) -> Double.compare(o2.values[0], o1.values[0])`. Also, the `(int)` cast truncates any difference smaller than 1 to 0, so such centroids would wrongly compare as equal.) ########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java ########## @@ -0,0 +1,487 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.clustering; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.RestOptions; +import org.apache.flink.metrics.Gauge; +import org.apache.flink.ml.clustering.kmeans.KMeans; +import org.apache.flink.ml.clustering.kmeans.KMeansModel; +import org.apache.flink.ml.clustering.kmeans.KMeansModelData; +import org.apache.flink.ml.clustering.kmeans.OnlineKMeans; +import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel; +import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.util.InMemorySinkFunction; +import org.apache.flink.ml.util.InMemorySourceFunction; +import org.apache.flink.runtime.minicluster.MiniCluster; +import org.apache.flink.runtime.minicluster.MiniClusterConfiguration; +import org.apache.flink.runtime.testutils.InMemoryReporter; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import 
org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.test.util.AbstractTestBase; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.CollectionUtils; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction; + +/** Tests {@link OnlineKMeans} and {@link OnlineKMeansModel}. */ +public class OnlineKMeansTest extends AbstractTestBase { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + + private static final DenseVector[] trainData1 = + new DenseVector[] { + Vectors.dense(10.0, 0.0), + Vectors.dense(10.0, 0.3), + Vectors.dense(10.3, 0.0), + Vectors.dense(-10.0, 0.0), + Vectors.dense(-10.0, 0.6), + Vectors.dense(-10.6, 0.0) + }; + private static final DenseVector[] trainData2 = + new DenseVector[] { + Vectors.dense(10.0, 100.0), + Vectors.dense(10.0, 100.3), + Vectors.dense(10.3, 100.0), + Vectors.dense(-10.0, -100.0), + Vectors.dense(-10.0, -100.6), + Vectors.dense(-10.6, -100.0) + }; + private static final DenseVector[] predictData = + new DenseVector[] { + Vectors.dense(10.0, 10.0), + Vectors.dense(10.3, 10.0), + Vectors.dense(10.0, 10.3), + Vectors.dense(-10.0, 10.0), + Vectors.dense(-10.3, 10.0), + Vectors.dense(-10.0, 10.3) + }; + private static final List<Set<DenseVector>> expectedGroups1 = + Arrays.asList( + new HashSet<>( + Arrays.asList( + Vectors.dense(10.0, 10.0), + Vectors.dense(10.3, 10.0), + Vectors.dense(10.0, 10.3))), + new HashSet<>( + Arrays.asList( + Vectors.dense(-10.0, 10.0), + Vectors.dense(-10.3, 10.0), + Vectors.dense(-10.0, 10.3)))); + private static final List<Set<DenseVector>> expectedGroups2 = + Collections.singletonList( + new HashSet<>( 
+ Arrays.asList( + Vectors.dense(10.0, 10.0), + Vectors.dense(10.3, 10.0), + Vectors.dense(10.0, 10.3), + Vectors.dense(-10.0, 10.0), + Vectors.dense(-10.3, 10.0), + Vectors.dense(-10.0, 10.3)))); + + private static final int defaultParallelism = 4; + private static final int numTaskManagers = 2; + private static final int numSlotsPerTaskManager = 2; + + private String currentModelDataVersion; + + private InMemorySourceFunction<DenseVector> trainSource; + private InMemorySourceFunction<DenseVector> predictSource; + private InMemorySinkFunction<Row> outputSink; + private InMemorySinkFunction<KMeansModelData> modelDataSink; + + private InMemoryReporter reporter; + private MiniCluster miniCluster; + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + + private Table offlineTrainTable; + private Table trainTable; + private Table predictTable; + + @Before + public void before() throws Exception { + currentModelDataVersion = "0"; + + trainSource = new InMemorySourceFunction<>(); + predictSource = new InMemorySourceFunction<>(); + outputSink = new InMemorySinkFunction<>(); + modelDataSink = new InMemorySinkFunction<>(); + + Configuration config = new Configuration(); + config.set(RestOptions.BIND_PORT, "18081-19091"); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + reporter = InMemoryReporter.createWithRetainedMetrics(); + reporter.addToConfiguration(config); + + miniCluster = + new MiniCluster( + new MiniClusterConfiguration.Builder() + .setConfiguration(config) + .setNumTaskManagers(numTaskManagers) + .setNumSlotsPerTaskManager(numSlotsPerTaskManager) + .build()); + miniCluster.start(); + + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(defaultParallelism); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + + Schema schema = Schema.newBuilder().column("f0", 
DataTypes.of(DenseVector.class)).build(); + + offlineTrainTable = + tEnv.fromDataStream(env.fromElements(trainData1), schema).as("features"); + trainTable = + tEnv.fromDataStream( + env.addSource(trainSource, DenseVectorTypeInfo.INSTANCE), schema) + .as("features"); + predictTable = + tEnv.fromDataStream( + env.addSource(predictSource, DenseVectorTypeInfo.INSTANCE), schema) + .as("features"); + } + + @After + public void after() throws Exception { + miniCluster.close(); + } + + /** + * Performs transform() on the provided model with predictTable, and adds sinks for + * OnlineKMeansModel's transform output and model data. + */ + private void configTransformAndSink(OnlineKMeansModel onlineModel) { + Table outputTable = onlineModel.transform(predictTable)[0]; + DataStream<Row> output = tEnv.toDataStream(outputTable); + output.addSink(outputSink); + + Table modelDataTable = onlineModel.getModelData()[0]; + DataStream<KMeansModelData> modelDataStream = + KMeansModelData.getModelDataStream(modelDataTable); + modelDataStream.addSink(modelDataSink); + } + + /** Blocks the thread until Model has set up init model data. */ + private void waitInitModelDataSetup() { + while (reporter.findMetrics("modelDataVersion").size() < defaultParallelism) { + Thread.yield(); + } + waitModelDataUpdate(); + } + + /** Blocks the thread until the Model has received the next model-data-update event. 
*/ + @SuppressWarnings("unchecked") + private void waitModelDataUpdate() { + do { + String tmpModelDataVersion = + String.valueOf( + reporter.findMetrics("modelDataVersion").values().stream() + .map(x -> Integer.parseInt(((Gauge<String>) x).getValue())) + .min(Integer::compareTo) + .orElse(Integer.parseInt(currentModelDataVersion))); + if (tmpModelDataVersion.equals(currentModelDataVersion)) { + Thread.yield(); + } else { + currentModelDataVersion = tmpModelDataVersion; + break; + } + } while (true); + } + + /** + * Inserts default predict data to the predict queue, fetches the prediction results, and + * asserts that the grouping result is as expected. + * + * @param expectedGroups A list containing sets of features, which is the expected group result + * @param featuresCol Name of the column in the table that contains the features + * @param predictionCol Name of the column in the table that contains the prediction result + */ + private void predictAndAssert( + List<Set<DenseVector>> expectedGroups, String featuresCol, String predictionCol) + throws Exception { + predictSource.offerAll(OnlineKMeansTest.predictData); + List<Row> rawResult = outputSink.poll(OnlineKMeansTest.predictData.length); + List<Set<DenseVector>> actualGroups = + groupFeaturesByPrediction(rawResult, featuresCol, predictionCol); + Assert.assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups)); + } + + @Test + public void testParam() { + OnlineKMeans onlineKMeans = new OnlineKMeans(); + Assert.assertEquals("features", onlineKMeans.getFeaturesCol()); + Assert.assertEquals("prediction", onlineKMeans.getPredictionCol()); + Assert.assertEquals(EuclideanDistanceMeasure.NAME, onlineKMeans.getDistanceMeasure()); + Assert.assertEquals("random", onlineKMeans.getInitMode()); Review comment: `onlineKMeans.getInitMode()` is checked twice. Should we remove this one? 
########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java ########## @@ -0,0 +1,487 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.clustering; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.RestOptions; +import org.apache.flink.metrics.Gauge; +import org.apache.flink.ml.clustering.kmeans.KMeans; +import org.apache.flink.ml.clustering.kmeans.KMeansModel; +import org.apache.flink.ml.clustering.kmeans.KMeansModelData; +import org.apache.flink.ml.clustering.kmeans.OnlineKMeans; +import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel; +import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.util.InMemorySinkFunction; +import org.apache.flink.ml.util.InMemorySourceFunction; +import org.apache.flink.runtime.minicluster.MiniCluster; +import 
org.apache.flink.runtime.minicluster.MiniClusterConfiguration; +import org.apache.flink.runtime.testutils.InMemoryReporter; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.test.util.AbstractTestBase; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.CollectionUtils; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction; + +/** Tests {@link OnlineKMeans} and {@link OnlineKMeansModel}. 
*/ +public class OnlineKMeansTest extends AbstractTestBase { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + + private static final DenseVector[] trainData1 = + new DenseVector[] { + Vectors.dense(10.0, 0.0), + Vectors.dense(10.0, 0.3), + Vectors.dense(10.3, 0.0), + Vectors.dense(-10.0, 0.0), + Vectors.dense(-10.0, 0.6), + Vectors.dense(-10.6, 0.0) + }; + private static final DenseVector[] trainData2 = + new DenseVector[] { + Vectors.dense(10.0, 100.0), + Vectors.dense(10.0, 100.3), + Vectors.dense(10.3, 100.0), + Vectors.dense(-10.0, -100.0), + Vectors.dense(-10.0, -100.6), + Vectors.dense(-10.6, -100.0) + }; + private static final DenseVector[] predictData = + new DenseVector[] { + Vectors.dense(10.0, 10.0), + Vectors.dense(10.3, 10.0), + Vectors.dense(10.0, 10.3), + Vectors.dense(-10.0, 10.0), + Vectors.dense(-10.3, 10.0), + Vectors.dense(-10.0, 10.3) + }; + private static final List<Set<DenseVector>> expectedGroups1 = + Arrays.asList( + new HashSet<>( + Arrays.asList( + Vectors.dense(10.0, 10.0), + Vectors.dense(10.3, 10.0), + Vectors.dense(10.0, 10.3))), + new HashSet<>( + Arrays.asList( + Vectors.dense(-10.0, 10.0), + Vectors.dense(-10.3, 10.0), + Vectors.dense(-10.0, 10.3)))); + private static final List<Set<DenseVector>> expectedGroups2 = + Collections.singletonList( + new HashSet<>( + Arrays.asList( + Vectors.dense(10.0, 10.0), + Vectors.dense(10.3, 10.0), + Vectors.dense(10.0, 10.3), + Vectors.dense(-10.0, 10.0), + Vectors.dense(-10.3, 10.0), + Vectors.dense(-10.0, 10.3)))); + + private static final int defaultParallelism = 4; + private static final int numTaskManagers = 2; + private static final int numSlotsPerTaskManager = 2; + + private String currentModelDataVersion; + + private InMemorySourceFunction<DenseVector> trainSource; + private InMemorySourceFunction<DenseVector> predictSource; + private InMemorySinkFunction<Row> outputSink; + private InMemorySinkFunction<KMeansModelData> modelDataSink; + + private InMemoryReporter 
reporter; + private MiniCluster miniCluster; + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + + private Table offlineTrainTable; + private Table trainTable; + private Table predictTable; + + @Before + public void before() throws Exception { + currentModelDataVersion = "0"; + + trainSource = new InMemorySourceFunction<>(); + predictSource = new InMemorySourceFunction<>(); + outputSink = new InMemorySinkFunction<>(); + modelDataSink = new InMemorySinkFunction<>(); + + Configuration config = new Configuration(); + config.set(RestOptions.BIND_PORT, "18081-19091"); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + reporter = InMemoryReporter.createWithRetainedMetrics(); + reporter.addToConfiguration(config); + + miniCluster = + new MiniCluster( + new MiniClusterConfiguration.Builder() + .setConfiguration(config) + .setNumTaskManagers(numTaskManagers) + .setNumSlotsPerTaskManager(numSlotsPerTaskManager) + .build()); + miniCluster.start(); + + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(defaultParallelism); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + + Schema schema = Schema.newBuilder().column("f0", DataTypes.of(DenseVector.class)).build(); + + offlineTrainTable = + tEnv.fromDataStream(env.fromElements(trainData1), schema).as("features"); + trainTable = + tEnv.fromDataStream( + env.addSource(trainSource, DenseVectorTypeInfo.INSTANCE), schema) + .as("features"); + predictTable = + tEnv.fromDataStream( + env.addSource(predictSource, DenseVectorTypeInfo.INSTANCE), schema) + .as("features"); + } + + @After + public void after() throws Exception { + miniCluster.close(); + } + + /** + * Performs transform() on the provided model with predictTable, and adds sinks for + * OnlineKMeansModel's transform output and model data. 
+ */ + private void configTransformAndSink(OnlineKMeansModel onlineModel) { + Table outputTable = onlineModel.transform(predictTable)[0]; + DataStream<Row> output = tEnv.toDataStream(outputTable); + output.addSink(outputSink); + + Table modelDataTable = onlineModel.getModelData()[0]; + DataStream<KMeansModelData> modelDataStream = + KMeansModelData.getModelDataStream(modelDataTable); + modelDataStream.addSink(modelDataSink); + } + + /** Blocks the thread until Model has set up init model data. */ + private void waitInitModelDataSetup() { + while (reporter.findMetrics("modelDataVersion").size() < defaultParallelism) { + Thread.yield(); Review comment: `Thread.yield()` could lead to a busy loop and waste CPU cycles. How about we use `Thread.sleep(100)`? ########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java ########## @@ -0,0 +1,487 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.flink.ml.clustering; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.RestOptions; +import org.apache.flink.metrics.Gauge; +import org.apache.flink.ml.clustering.kmeans.KMeans; +import org.apache.flink.ml.clustering.kmeans.KMeansModel; +import org.apache.flink.ml.clustering.kmeans.KMeansModelData; +import org.apache.flink.ml.clustering.kmeans.OnlineKMeans; +import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel; +import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.util.InMemorySinkFunction; +import org.apache.flink.ml.util.InMemorySourceFunction; +import org.apache.flink.runtime.minicluster.MiniCluster; +import org.apache.flink.runtime.minicluster.MiniClusterConfiguration; +import org.apache.flink.runtime.testutils.InMemoryReporter; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.test.util.AbstractTestBase; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.CollectionUtils; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import 
java.util.List; +import java.util.Set; + +import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction; + +/** Tests {@link OnlineKMeans} and {@link OnlineKMeansModel}. */ +public class OnlineKMeansTest extends AbstractTestBase { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + + private static final DenseVector[] trainData1 = + new DenseVector[] { + Vectors.dense(10.0, 0.0), + Vectors.dense(10.0, 0.3), + Vectors.dense(10.3, 0.0), + Vectors.dense(-10.0, 0.0), + Vectors.dense(-10.0, 0.6), + Vectors.dense(-10.6, 0.0) + }; + private static final DenseVector[] trainData2 = + new DenseVector[] { + Vectors.dense(10.0, 100.0), + Vectors.dense(10.0, 100.3), + Vectors.dense(10.3, 100.0), + Vectors.dense(-10.0, -100.0), + Vectors.dense(-10.0, -100.6), + Vectors.dense(-10.6, -100.0) + }; + private static final DenseVector[] predictData = + new DenseVector[] { + Vectors.dense(10.0, 10.0), + Vectors.dense(10.3, 10.0), + Vectors.dense(10.0, 10.3), + Vectors.dense(-10.0, 10.0), + Vectors.dense(-10.3, 10.0), + Vectors.dense(-10.0, 10.3) + }; + private static final List<Set<DenseVector>> expectedGroups1 = + Arrays.asList( + new HashSet<>( + Arrays.asList( + Vectors.dense(10.0, 10.0), + Vectors.dense(10.3, 10.0), + Vectors.dense(10.0, 10.3))), + new HashSet<>( + Arrays.asList( + Vectors.dense(-10.0, 10.0), + Vectors.dense(-10.3, 10.0), + Vectors.dense(-10.0, 10.3)))); + private static final List<Set<DenseVector>> expectedGroups2 = + Collections.singletonList( + new HashSet<>( + Arrays.asList( + Vectors.dense(10.0, 10.0), + Vectors.dense(10.3, 10.0), + Vectors.dense(10.0, 10.3), + Vectors.dense(-10.0, 10.0), + Vectors.dense(-10.3, 10.0), + Vectors.dense(-10.0, 10.3)))); + + private static final int defaultParallelism = 4; + private static final int numTaskManagers = 2; + private static final int numSlotsPerTaskManager = 2; + + private String currentModelDataVersion; + + private InMemorySourceFunction<DenseVector> trainSource; + private 
InMemorySourceFunction<DenseVector> predictSource; + private InMemorySinkFunction<Row> outputSink; + private InMemorySinkFunction<KMeansModelData> modelDataSink; + + private InMemoryReporter reporter; + private MiniCluster miniCluster; + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + + private Table offlineTrainTable; + private Table trainTable; + private Table predictTable; + + @Before + public void before() throws Exception { + currentModelDataVersion = "0"; + + trainSource = new InMemorySourceFunction<>(); + predictSource = new InMemorySourceFunction<>(); + outputSink = new InMemorySinkFunction<>(); + modelDataSink = new InMemorySinkFunction<>(); + + Configuration config = new Configuration(); + config.set(RestOptions.BIND_PORT, "18081-19091"); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + reporter = InMemoryReporter.createWithRetainedMetrics(); + reporter.addToConfiguration(config); + + miniCluster = + new MiniCluster( + new MiniClusterConfiguration.Builder() + .setConfiguration(config) + .setNumTaskManagers(numTaskManagers) + .setNumSlotsPerTaskManager(numSlotsPerTaskManager) + .build()); + miniCluster.start(); + + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(defaultParallelism); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + + Schema schema = Schema.newBuilder().column("f0", DataTypes.of(DenseVector.class)).build(); + + offlineTrainTable = + tEnv.fromDataStream(env.fromElements(trainData1), schema).as("features"); + trainTable = + tEnv.fromDataStream( + env.addSource(trainSource, DenseVectorTypeInfo.INSTANCE), schema) + .as("features"); + predictTable = + tEnv.fromDataStream( + env.addSource(predictSource, DenseVectorTypeInfo.INSTANCE), schema) + .as("features"); + } + + @After + public void after() throws Exception { + miniCluster.close(); + } + + /** 
+ * Performs transform() on the provided model with predictTable, and adds sinks for + * OnlineKMeansModel's transform output and model data. + */ + private void configTransformAndSink(OnlineKMeansModel onlineModel) { + Table outputTable = onlineModel.transform(predictTable)[0]; + DataStream<Row> output = tEnv.toDataStream(outputTable); + output.addSink(outputSink); + + Table modelDataTable = onlineModel.getModelData()[0]; + DataStream<KMeansModelData> modelDataStream = + KMeansModelData.getModelDataStream(modelDataTable); + modelDataStream.addSink(modelDataSink); + } + + /** Blocks the thread until Model has set up init model data. */ + private void waitInitModelDataSetup() { + while (reporter.findMetrics("modelDataVersion").size() < defaultParallelism) { + Thread.yield(); + } + waitModelDataUpdate(); + } + + /** Blocks the thread until the Model has received the next model-data-update event. */ + @SuppressWarnings("unchecked") + private void waitModelDataUpdate() { + do { + String tmpModelDataVersion = + String.valueOf( + reporter.findMetrics("modelDataVersion").values().stream() + .map(x -> Integer.parseInt(((Gauge<String>) x).getValue())) + .min(Integer::compareTo) + .orElse(Integer.parseInt(currentModelDataVersion))); + if (tmpModelDataVersion.equals(currentModelDataVersion)) { + Thread.yield(); + } else { + currentModelDataVersion = tmpModelDataVersion; + break; + } + } while (true); + } + + /** + * Inserts default predict data to the predict queue, fetches the prediction results, and + * asserts that the grouping result is as expected. 
+ * + * @param expectedGroups A list containing sets of features, which is the expected group result + * @param featuresCol Name of the column in the table that contains the features + * @param predictionCol Name of the column in the table that contains the prediction result + */ + private void predictAndAssert( + List<Set<DenseVector>> expectedGroups, String featuresCol, String predictionCol) + throws Exception { + predictSource.offerAll(OnlineKMeansTest.predictData); + List<Row> rawResult = outputSink.poll(OnlineKMeansTest.predictData.length); + List<Set<DenseVector>> actualGroups = + groupFeaturesByPrediction(rawResult, featuresCol, predictionCol); + Assert.assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups)); + } + + @Test + public void testParam() { + OnlineKMeans onlineKMeans = new OnlineKMeans(); + Assert.assertEquals("features", onlineKMeans.getFeaturesCol()); + Assert.assertEquals("prediction", onlineKMeans.getPredictionCol()); + Assert.assertEquals(EuclideanDistanceMeasure.NAME, onlineKMeans.getDistanceMeasure()); + Assert.assertEquals("random", onlineKMeans.getInitMode()); + Assert.assertEquals(2, onlineKMeans.getK()); + Assert.assertEquals(1, onlineKMeans.getDims()); + Assert.assertEquals("count", onlineKMeans.getBatchStrategy()); + Assert.assertEquals(1, onlineKMeans.getBatchSize()); + Assert.assertEquals(0., onlineKMeans.getDecayFactor(), 1e-5); + Assert.assertEquals("random", onlineKMeans.getInitMode()); + Assert.assertEquals(OnlineKMeans.class.getName().hashCode(), onlineKMeans.getSeed()); + + onlineKMeans + .setK(9) Review comment: This `.setK(9)` is overridden by the `.setK(3)` call later in the same chain, so it has no effect. Should we remove this one? ########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java ########## @@ -0,0 +1,487 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.clustering; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.RestOptions; +import org.apache.flink.metrics.Gauge; +import org.apache.flink.ml.clustering.kmeans.KMeans; +import org.apache.flink.ml.clustering.kmeans.KMeansModel; +import org.apache.flink.ml.clustering.kmeans.KMeansModelData; +import org.apache.flink.ml.clustering.kmeans.OnlineKMeans; +import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel; +import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.util.InMemorySinkFunction; +import org.apache.flink.ml.util.InMemorySourceFunction; +import org.apache.flink.runtime.minicluster.MiniCluster; +import org.apache.flink.runtime.minicluster.MiniClusterConfiguration; +import org.apache.flink.runtime.testutils.InMemoryReporter; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import 
org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.test.util.AbstractTestBase; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.CollectionUtils; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction; + +/** Tests {@link OnlineKMeans} and {@link OnlineKMeansModel}. */ +public class OnlineKMeansTest extends AbstractTestBase { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + + private static final DenseVector[] trainData1 = + new DenseVector[] { + Vectors.dense(10.0, 0.0), + Vectors.dense(10.0, 0.3), + Vectors.dense(10.3, 0.0), + Vectors.dense(-10.0, 0.0), + Vectors.dense(-10.0, 0.6), + Vectors.dense(-10.6, 0.0) + }; + private static final DenseVector[] trainData2 = + new DenseVector[] { + Vectors.dense(10.0, 100.0), + Vectors.dense(10.0, 100.3), + Vectors.dense(10.3, 100.0), + Vectors.dense(-10.0, -100.0), + Vectors.dense(-10.0, -100.6), + Vectors.dense(-10.6, -100.0) + }; + private static final DenseVector[] predictData = + new DenseVector[] { + Vectors.dense(10.0, 10.0), + Vectors.dense(10.3, 10.0), + Vectors.dense(10.0, 10.3), + Vectors.dense(-10.0, 10.0), + Vectors.dense(-10.3, 10.0), + Vectors.dense(-10.0, 10.3) + }; + private static final List<Set<DenseVector>> expectedGroups1 = + Arrays.asList( + new HashSet<>( + Arrays.asList( + Vectors.dense(10.0, 10.0), + Vectors.dense(10.3, 10.0), + Vectors.dense(10.0, 10.3))), + new HashSet<>( + Arrays.asList( + Vectors.dense(-10.0, 10.0), + Vectors.dense(-10.3, 10.0), + 
Vectors.dense(-10.0, 10.3)))); + private static final List<Set<DenseVector>> expectedGroups2 = + Collections.singletonList( + new HashSet<>( + Arrays.asList( + Vectors.dense(10.0, 10.0), + Vectors.dense(10.3, 10.0), + Vectors.dense(10.0, 10.3), + Vectors.dense(-10.0, 10.0), + Vectors.dense(-10.3, 10.0), + Vectors.dense(-10.0, 10.3)))); + + private static final int defaultParallelism = 4; + private static final int numTaskManagers = 2; + private static final int numSlotsPerTaskManager = 2; + + private String currentModelDataVersion; + + private InMemorySourceFunction<DenseVector> trainSource; + private InMemorySourceFunction<DenseVector> predictSource; + private InMemorySinkFunction<Row> outputSink; + private InMemorySinkFunction<KMeansModelData> modelDataSink; + + private InMemoryReporter reporter; + private MiniCluster miniCluster; + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + + private Table offlineTrainTable; + private Table trainTable; + private Table predictTable; + + @Before + public void before() throws Exception { + currentModelDataVersion = "0"; + + trainSource = new InMemorySourceFunction<>(); + predictSource = new InMemorySourceFunction<>(); + outputSink = new InMemorySinkFunction<>(); + modelDataSink = new InMemorySinkFunction<>(); + + Configuration config = new Configuration(); + config.set(RestOptions.BIND_PORT, "18081-19091"); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + reporter = InMemoryReporter.createWithRetainedMetrics(); + reporter.addToConfiguration(config); + + miniCluster = + new MiniCluster( + new MiniClusterConfiguration.Builder() + .setConfiguration(config) + .setNumTaskManagers(numTaskManagers) + .setNumSlotsPerTaskManager(numSlotsPerTaskManager) + .build()); + miniCluster.start(); + + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(defaultParallelism); + env.enableCheckpointing(100); + 
env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + + Schema schema = Schema.newBuilder().column("f0", DataTypes.of(DenseVector.class)).build(); + + offlineTrainTable = + tEnv.fromDataStream(env.fromElements(trainData1), schema).as("features"); + trainTable = + tEnv.fromDataStream( + env.addSource(trainSource, DenseVectorTypeInfo.INSTANCE), schema) + .as("features"); + predictTable = + tEnv.fromDataStream( + env.addSource(predictSource, DenseVectorTypeInfo.INSTANCE), schema) + .as("features"); + } + + @After + public void after() throws Exception { + miniCluster.close(); + } + + /** + * Performs transform() on the provided model with predictTable, and adds sinks for + * OnlineKMeansModel's transform output and model data. + */ + private void configTransformAndSink(OnlineKMeansModel onlineModel) { + Table outputTable = onlineModel.transform(predictTable)[0]; + DataStream<Row> output = tEnv.toDataStream(outputTable); + output.addSink(outputSink); + + Table modelDataTable = onlineModel.getModelData()[0]; + DataStream<KMeansModelData> modelDataStream = + KMeansModelData.getModelDataStream(modelDataTable); + modelDataStream.addSink(modelDataSink); + } + + /** Blocks the thread until Model has set up init model data. */ + private void waitInitModelDataSetup() { + while (reporter.findMetrics("modelDataVersion").size() < defaultParallelism) { + Thread.yield(); + } + waitModelDataUpdate(); + } + + /** Blocks the thread until the Model has received the next model-data-update event. 
*/ + @SuppressWarnings("unchecked") + private void waitModelDataUpdate() { + do { + String tmpModelDataVersion = + String.valueOf( + reporter.findMetrics("modelDataVersion").values().stream() + .map(x -> Integer.parseInt(((Gauge<String>) x).getValue())) + .min(Integer::compareTo) + .orElse(Integer.parseInt(currentModelDataVersion))); + if (tmpModelDataVersion.equals(currentModelDataVersion)) { + Thread.yield(); + } else { + currentModelDataVersion = tmpModelDataVersion; + break; + } + } while (true); + } + + /** + * Inserts default predict data to the predict queue, fetches the prediction results, and + * asserts that the grouping result is as expected. + * + * @param expectedGroups A list containing sets of features, which is the expected group result + * @param featuresCol Name of the column in the table that contains the features + * @param predictionCol Name of the column in the table that contains the prediction result + */ + private void predictAndAssert( + List<Set<DenseVector>> expectedGroups, String featuresCol, String predictionCol) + throws Exception { + predictSource.offerAll(OnlineKMeansTest.predictData); + List<Row> rawResult = outputSink.poll(OnlineKMeansTest.predictData.length); + List<Set<DenseVector>> actualGroups = + groupFeaturesByPrediction(rawResult, featuresCol, predictionCol); + Assert.assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups)); + } + + @Test + public void testParam() { + OnlineKMeans onlineKMeans = new OnlineKMeans(); + Assert.assertEquals("features", onlineKMeans.getFeaturesCol()); + Assert.assertEquals("prediction", onlineKMeans.getPredictionCol()); + Assert.assertEquals(EuclideanDistanceMeasure.NAME, onlineKMeans.getDistanceMeasure()); + Assert.assertEquals("random", onlineKMeans.getInitMode()); + Assert.assertEquals(2, onlineKMeans.getK()); + Assert.assertEquals(1, onlineKMeans.getDims()); + Assert.assertEquals("count", onlineKMeans.getBatchStrategy()); + Assert.assertEquals(1, 
onlineKMeans.getBatchSize()); + Assert.assertEquals(0., onlineKMeans.getDecayFactor(), 1e-5); + Assert.assertEquals("random", onlineKMeans.getInitMode()); + Assert.assertEquals(OnlineKMeans.class.getName().hashCode(), onlineKMeans.getSeed()); + + onlineKMeans + .setK(9) + .setFeaturesCol("test_feature") + .setPredictionCol("test_prediction") + .setK(3) + .setDims(5) + .setBatchSize(5) + .setDecayFactor(0.25) + .setInitMode("direct") + .setSeed(100); + + Assert.assertEquals("test_feature", onlineKMeans.getFeaturesCol()); + Assert.assertEquals("test_prediction", onlineKMeans.getPredictionCol()); + Assert.assertEquals(3, onlineKMeans.getK()); + Assert.assertEquals(5, onlineKMeans.getDims()); + Assert.assertEquals("count", onlineKMeans.getBatchStrategy()); + Assert.assertEquals(5, onlineKMeans.getBatchSize()); + Assert.assertEquals(0.25, onlineKMeans.getDecayFactor(), 1e-5); + Assert.assertEquals("direct", onlineKMeans.getInitMode()); + Assert.assertEquals(100, onlineKMeans.getSeed()); + } + + @Test + public void testFitAndPredict() throws Exception { + OnlineKMeans onlineKMeans = + new OnlineKMeans() + .setInitMode("random") Review comment: nits: could we use the same order of setXXX(...) across unit tests? ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java ########## @@ -0,0 +1,569 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.clustering.kmeans; + +import org.apache.flink.api.common.functions.AggregateFunction; +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo; +import org.apache.flink.api.java.typeutils.TupleTypeInfo; +import org.apache.flink.iteration.DataStreamList; +import org.apache.flink.iteration.IterationBody; +import org.apache.flink.iteration.IterationBodyResult; +import org.apache.flink.iteration.Iterations; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.distance.DistanceMeasure; +import org.apache.flink.ml.common.param.HasBatchStrategy; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import 
org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.collections.IteratorUtils; +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Random; + +/** + * OnlineKMeans extends the function of {@link KMeans}, supporting to train a K-Means model + * continuously according to an unbounded stream of train data. + * + * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate + * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired, + * OnlineKMeans computes the new centroids from the weighted average between the original and the + * estimated centroids. The weight of the estimated centroids is the number of points assigned to + * them. The weight of the original centroids is also the number of points, but additionally + * multiplying with the decay factor. + * + * <p>The decay factor scales the contribution of the clusters as estimated thus far. If decay + * factor is 1, all batches are weighted equally. If decay factor is 0, new centroids are determined + * entirely by recent data. Lower values correspond to more forgetting. 
+ * + * <p>NOTE: This class's current naive implementation performs the training process in a + * single-threaded way. Correctness is not affected but there are performance issues. Review comment: Given that the current implementation supports parallelization, could we remove this statement now? ########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/InMemorySinkFunction.java ########## @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.util; + +import org.apache.flink.configuration.Configuration; +import org.apache.flink.streaming.api.functions.sink.RichSinkFunction; +import org.apache.flink.streaming.api.functions.sink.SinkFunction; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +/** A {@link SinkFunction} implementation that makes all collected records available for tests. 
*/ +@SuppressWarnings({"unchecked", "rawtypes"}) +public class InMemorySinkFunction<T> extends RichSinkFunction<T> { + private static final Map<UUID, BlockingQueue> queueMap = new ConcurrentHashMap<>(); + private final UUID id; + private BlockingQueue<T> queue; + + public InMemorySinkFunction() { + id = UUID.randomUUID(); + queue = new LinkedBlockingQueue(); + queueMap.put(id, queue); + } + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + queue = queueMap.get(id); + } + + @Override + public void close() throws Exception { + super.close(); + queueMap.remove(id); + } + + @Override + public void invoke(T value, Context context) { + if (!queue.offer(value)) { + throw new RuntimeException( + "Failed to offer " + value + " to blocking queue " + id + "."); + } + } + + public List<T> poll(int num) throws InterruptedException { Review comment: `poll(...)` typically takes a timeout, and it is allowed to return a null value if there are not enough values in the queue before the timeout expires. This method instead requires that at least the expected number of values be in the queue, and throws an exception otherwise. These semantics seem closer to `take()`. How about renaming this method and the method below to `take(...)`? ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java ########## @@ -0,0 +1,569 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.clustering.kmeans; + +import org.apache.flink.api.common.functions.AggregateFunction; +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo; +import org.apache.flink.api.java.typeutils.TupleTypeInfo; +import org.apache.flink.iteration.DataStreamList; +import org.apache.flink.iteration.IterationBody; +import org.apache.flink.iteration.IterationBodyResult; +import org.apache.flink.iteration.Iterations; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.distance.DistanceMeasure; +import org.apache.flink.ml.common.param.HasBatchStrategy; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import 
org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.collections.IteratorUtils; +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Random; + +/** + * OnlineKMeans extends the function of {@link KMeans}, supporting to train a K-Means model + * continuously according to an unbounded stream of train data. + * + * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate + * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired, + * OnlineKMeans computes the new centroids from the weighted average between the original and the + * estimated centroids. The weight of the estimated centroids is the number of points assigned to + * them. The weight of the original centroids is also the number of points, but additionally + * multiplying with the decay factor. + * + * <p>The decay factor scales the contribution of the clusters as estimated thus far. If decay + * factor is 1, all batches are weighted equally. If decay factor is 0, new centroids are determined + * entirely by recent data. Lower values correspond to more forgetting. 
+ * + * <p>NOTE: This class's current naive implementation performs the training process in a + * single-threaded way. Correctness is not affected but there are performance issues. + */ +public class OnlineKMeans + implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private Table initModelDataTable; + + public OnlineKMeans() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + public OnlineKMeans(Table... initModelDataTables) { + Preconditions.checkArgument(initModelDataTables.length == 1); + this.initModelDataTable = initModelDataTables[0]; + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + setInitMode("direct"); + } + + @Override + public OnlineKMeansModel fit(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + Preconditions.checkArgument(HasBatchStrategy.COUNT_STRATEGY.equals(getBatchStrategy())); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + StreamExecutionEnvironment env = ((StreamTableEnvironmentImpl) tEnv).execEnv(); + + DataStream<DenseVector> points = + tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol())); + + DataStream<KMeansModelData> initModelDataStream; + if (getInitMode().equals("random")) { + Preconditions.checkState(initModelDataTable == null); + initModelDataStream = createRandomCentroids(env, getDims(), getK(), getSeed()); + } else { + initModelDataStream = KMeansModelData.getModelDataStream(initModelDataTable); + } + DataStream<Tuple2<KMeansModelData, DenseVector>> initModelDataWithWeightsStream = + initModelDataStream.map(new InitWeightAssigner(getInitWeights())); + initModelDataWithWeightsStream.getTransformation().setParallelism(1); + + IterationBody body = + new OnlineKMeansIterationBody( + DistanceMeasure.getInstance(getDistanceMeasure()), + getDecayFactor(), + getBatchSize(), + getK(), + getDims()); + + 
DataStream<KMeansModelData> finalModelDataStream = + Iterations.iterateUnboundedStreams( + DataStreamList.of(initModelDataWithWeightsStream), + DataStreamList.of(points), + body) + .get(0); + + Table finalModelDataTable = tEnv.fromDataStream(finalModelDataStream); + OnlineKMeansModel model = new OnlineKMeansModel().setModelData(finalModelDataTable); + ReadWriteUtils.updateExistingParams(model, paramMap); + return model; + } + + private static class InitWeightAssigner + implements MapFunction<KMeansModelData, Tuple2<KMeansModelData, DenseVector>> { + private final double[] initWeights; + + private InitWeightAssigner(Double[] initWeights) { + this.initWeights = ArrayUtils.toPrimitive(initWeights); + } + + @Override + public Tuple2<KMeansModelData, DenseVector> map(KMeansModelData modelData) + throws Exception { + return Tuple2.of(modelData, Vectors.dense(initWeights)); + } + } + + @Override + public void save(String path) throws IOException { + if (initModelDataTable != null) { + ReadWriteUtils.saveModelData( + KMeansModelData.getModelDataStream(initModelDataTable), + path, + new KMeansModelData.ModelDataEncoder()); + } + + ReadWriteUtils.saveMetadata(this, path); + } + + public static OnlineKMeans load(StreamExecutionEnvironment env, String path) + throws IOException { + OnlineKMeans kMeans = ReadWriteUtils.loadStageParam(path); + + Path initModelDataPath = Paths.get(path, "data"); + if (Files.exists(initModelDataPath)) { + StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); + + DataStream<KMeansModelData> initModelDataStream = + ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder()); + + kMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream); + kMeans.setInitMode("direct"); + } + + return kMeans; + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + private static class OnlineKMeansIterationBody implements IterationBody { + private final DistanceMeasure distanceMeasure; + 
private final double decayFactor; + private final int batchSize; + private final int k; + private final int dims; + + public OnlineKMeansIterationBody( + DistanceMeasure distanceMeasure, + double decayFactor, + int batchSize, + int k, + int dims) { + this.distanceMeasure = distanceMeasure; + this.decayFactor = decayFactor; + this.batchSize = batchSize; + this.k = k; + this.dims = dims; + } + + @Override + public IterationBodyResult process( + DataStreamList variableStreams, DataStreamList dataStreams) { + DataStream<Tuple2<KMeansModelData, DenseVector>> modelDataWithWeights = + variableStreams.get(0); + DataStream<DenseVector> points = dataStreams.get(0); + + int parallelism = points.getParallelism(); + + DataStream<Tuple2<KMeansModelData, DenseVector>> newModelDataWithWeights = + points.countWindowAll(batchSize) + .aggregate(new MiniBatchCreator()) + .flatMap(new MiniBatchDistributor(parallelism)) + .rebalance() + .connect(modelDataWithWeights.broadcast()) + .transform( + "ModelDataPartialUpdater", + new TupleTypeInfo<>( + TypeInformation.of(KMeansModelData.class), + DenseVectorTypeInfo.INSTANCE), + new ModelDataPartialUpdater(distanceMeasure, k)) + .setParallelism(parallelism) + .connect(modelDataWithWeights.broadcast()) + .transform( + "ModelDataGlobalUpdater", + new TupleTypeInfo<>( + TypeInformation.of(KMeansModelData.class), + DenseVectorTypeInfo.INSTANCE), + new ModelDataGlobalUpdater(k, dims, parallelism, decayFactor)) + .setParallelism(1); + + DataStream<KMeansModelData> outputModelData = + modelDataWithWeights.map( + (MapFunction<Tuple2<KMeansModelData, DenseVector>, KMeansModelData>) + x -> x.f0); + + return new IterationBodyResult( + DataStreamList.of(newModelDataWithWeights), DataStreamList.of(outputModelData)); + } + } + + private static class ModelDataGlobalUpdater + extends AbstractStreamOperator<Tuple2<KMeansModelData, DenseVector>> + implements TwoInputStreamOperator< + Tuple2<KMeansModelData, DenseVector>, + Tuple2<KMeansModelData, DenseVector>, 
+ Tuple2<KMeansModelData, DenseVector>> { + private final int k; + private final int dims; + private final int upstreamParallelism; + private final double decayFactor; + + private ListState<Integer> partialModelDataReceivingState; + private ListState<Boolean> initModelDataReceivingState; + private ListState<KMeansModelData> modelDataState; + private ListState<DenseVector> weightsState; + + private ModelDataGlobalUpdater( + int k, int dims, int upstreamParallelism, double decayFactor) { + this.k = k; + this.dims = dims; + this.upstreamParallelism = upstreamParallelism; + this.decayFactor = decayFactor; + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + + partialModelDataReceivingState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "partialModelDataReceiving", + TypeInformation.of(Integer.class))); Review comment: My understanding is that for basic types like int, we use `BasicTypeInfo.INT_TYPE_INFO`. Not sure if `TypeInformation.of(Integer.class)` would have inferior performance. It seems simpler to just use `BasicTypeInfo.INT_TYPE_INFO`. Alternatively, please feel free to ask Yun Gao what is the recommended approach. ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasBatchStrategy.java ########## @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.param; + +import org.apache.flink.ml.param.IntParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringParam; +import org.apache.flink.ml.param.WithParams; + +/** Interface for the shared batch strategy param. */ +@SuppressWarnings("unchecked") +public interface HasBatchStrategy<T> extends WithParams<T> { + String COUNT_STRATEGY = "count"; + + Param<String> BATCH_STRATEGY = + new StringParam( + "batchStrategy", + "Strategy to create mini batch from online train data.", + COUNT_STRATEGY, + ParamValidators.inArray(COUNT_STRATEGY)); + + Param<Integer> BATCH_SIZE = Review comment: There is a difference between global batch size and local batch size: `global_batch_size = local_batch_size * num_worker`. This is well defined when multiple workers are used in synchronous training, which seems to be the case for `OnlineKMeans`. Instead of re-defining a global batch size parameter here, how about we reuse `HasGlobalBatchSize`? ########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java ########## @@ -0,0 +1,487 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.clustering; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.RestOptions; +import org.apache.flink.metrics.Gauge; +import org.apache.flink.ml.clustering.kmeans.KMeans; +import org.apache.flink.ml.clustering.kmeans.KMeansModel; +import org.apache.flink.ml.clustering.kmeans.KMeansModelData; +import org.apache.flink.ml.clustering.kmeans.OnlineKMeans; +import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel; +import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.util.InMemorySinkFunction; +import org.apache.flink.ml.util.InMemorySourceFunction; +import org.apache.flink.runtime.minicluster.MiniCluster; +import org.apache.flink.runtime.minicluster.MiniClusterConfiguration; +import org.apache.flink.runtime.testutils.InMemoryReporter; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import 
org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.test.util.AbstractTestBase; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.CollectionUtils; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction; + +/** Tests {@link OnlineKMeans} and {@link OnlineKMeansModel}. */ +public class OnlineKMeansTest extends AbstractTestBase { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + + private static final DenseVector[] trainData1 = + new DenseVector[] { + Vectors.dense(10.0, 0.0), + Vectors.dense(10.0, 0.3), + Vectors.dense(10.3, 0.0), + Vectors.dense(-10.0, 0.0), + Vectors.dense(-10.0, 0.6), + Vectors.dense(-10.6, 0.0) + }; + private static final DenseVector[] trainData2 = + new DenseVector[] { + Vectors.dense(10.0, 100.0), + Vectors.dense(10.0, 100.3), + Vectors.dense(10.3, 100.0), + Vectors.dense(-10.0, -100.0), + Vectors.dense(-10.0, -100.6), + Vectors.dense(-10.6, -100.0) + }; + private static final DenseVector[] predictData = + new DenseVector[] { + Vectors.dense(10.0, 10.0), + Vectors.dense(10.3, 10.0), + Vectors.dense(10.0, 10.3), + Vectors.dense(-10.0, 10.0), + Vectors.dense(-10.3, 10.0), + Vectors.dense(-10.0, 10.3) + }; + private static final List<Set<DenseVector>> expectedGroups1 = + Arrays.asList( + new HashSet<>( + Arrays.asList( + Vectors.dense(10.0, 10.0), + Vectors.dense(10.3, 10.0), + Vectors.dense(10.0, 10.3))), + new HashSet<>( + Arrays.asList( + Vectors.dense(-10.0, 10.0), + Vectors.dense(-10.3, 10.0), + Vectors.dense(-10.0, 10.3)))); + private static final List<Set<DenseVector>> expectedGroups2 = + Collections.singletonList( + new HashSet<>( 
+ Arrays.asList( + Vectors.dense(10.0, 10.0), + Vectors.dense(10.3, 10.0), + Vectors.dense(10.0, 10.3), + Vectors.dense(-10.0, 10.0), + Vectors.dense(-10.3, 10.0), + Vectors.dense(-10.0, 10.3)))); + + private static final int defaultParallelism = 4; + private static final int numTaskManagers = 2; + private static final int numSlotsPerTaskManager = 2; + + private String currentModelDataVersion; Review comment: Would it be simpler to use `int currentModelDataVersion`? ########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/OnlineKMeansTest.java ########## @@ -0,0 +1,487 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.clustering; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.configuration.RestOptions; +import org.apache.flink.metrics.Gauge; +import org.apache.flink.ml.clustering.kmeans.KMeans; +import org.apache.flink.ml.clustering.kmeans.KMeansModel; +import org.apache.flink.ml.clustering.kmeans.KMeansModelData; +import org.apache.flink.ml.clustering.kmeans.OnlineKMeans; +import org.apache.flink.ml.clustering.kmeans.OnlineKMeansModel; +import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.util.InMemorySinkFunction; +import org.apache.flink.ml.util.InMemorySourceFunction; +import org.apache.flink.runtime.minicluster.MiniCluster; +import org.apache.flink.runtime.minicluster.MiniClusterConfiguration; +import org.apache.flink.runtime.testutils.InMemoryReporter; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.test.util.AbstractTestBase; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.CollectionUtils; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import 
java.util.List; +import java.util.Set; + +import static org.apache.flink.ml.clustering.KMeansTest.groupFeaturesByPrediction; + +/** Tests {@link OnlineKMeans} and {@link OnlineKMeansModel}. */ +public class OnlineKMeansTest extends AbstractTestBase { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + + private static final DenseVector[] trainData1 = + new DenseVector[] { + Vectors.dense(10.0, 0.0), + Vectors.dense(10.0, 0.3), + Vectors.dense(10.3, 0.0), + Vectors.dense(-10.0, 0.0), + Vectors.dense(-10.0, 0.6), + Vectors.dense(-10.6, 0.0) + }; + private static final DenseVector[] trainData2 = + new DenseVector[] { + Vectors.dense(10.0, 100.0), + Vectors.dense(10.0, 100.3), + Vectors.dense(10.3, 100.0), + Vectors.dense(-10.0, -100.0), + Vectors.dense(-10.0, -100.6), + Vectors.dense(-10.6, -100.0) + }; + private static final DenseVector[] predictData = + new DenseVector[] { + Vectors.dense(10.0, 10.0), + Vectors.dense(10.3, 10.0), + Vectors.dense(10.0, 10.3), + Vectors.dense(-10.0, 10.0), + Vectors.dense(-10.3, 10.0), + Vectors.dense(-10.0, 10.3) + }; + private static final List<Set<DenseVector>> expectedGroups1 = + Arrays.asList( + new HashSet<>( + Arrays.asList( + Vectors.dense(10.0, 10.0), + Vectors.dense(10.3, 10.0), + Vectors.dense(10.0, 10.3))), + new HashSet<>( + Arrays.asList( + Vectors.dense(-10.0, 10.0), + Vectors.dense(-10.3, 10.0), + Vectors.dense(-10.0, 10.3)))); + private static final List<Set<DenseVector>> expectedGroups2 = + Collections.singletonList( + new HashSet<>( + Arrays.asList( + Vectors.dense(10.0, 10.0), + Vectors.dense(10.3, 10.0), + Vectors.dense(10.0, 10.3), + Vectors.dense(-10.0, 10.0), + Vectors.dense(-10.3, 10.0), + Vectors.dense(-10.0, 10.3)))); + + private static final int defaultParallelism = 4; + private static final int numTaskManagers = 2; + private static final int numSlotsPerTaskManager = 2; + + private String currentModelDataVersion; + + private InMemorySourceFunction<DenseVector> trainSource; + private 
InMemorySourceFunction<DenseVector> predictSource; + private InMemorySinkFunction<Row> outputSink; + private InMemorySinkFunction<KMeansModelData> modelDataSink; + + private InMemoryReporter reporter; + private MiniCluster miniCluster; + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + + private Table offlineTrainTable; + private Table trainTable; + private Table predictTable; + + @Before + public void before() throws Exception { + currentModelDataVersion = "0"; + + trainSource = new InMemorySourceFunction<>(); + predictSource = new InMemorySourceFunction<>(); + outputSink = new InMemorySinkFunction<>(); + modelDataSink = new InMemorySinkFunction<>(); + + Configuration config = new Configuration(); + config.set(RestOptions.BIND_PORT, "18081-19091"); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + reporter = InMemoryReporter.createWithRetainedMetrics(); + reporter.addToConfiguration(config); + + miniCluster = + new MiniCluster( + new MiniClusterConfiguration.Builder() + .setConfiguration(config) + .setNumTaskManagers(numTaskManagers) + .setNumSlotsPerTaskManager(numSlotsPerTaskManager) + .build()); + miniCluster.start(); + + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(defaultParallelism); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + + Schema schema = Schema.newBuilder().column("f0", DataTypes.of(DenseVector.class)).build(); + + offlineTrainTable = + tEnv.fromDataStream(env.fromElements(trainData1), schema).as("features"); + trainTable = + tEnv.fromDataStream( + env.addSource(trainSource, DenseVectorTypeInfo.INSTANCE), schema) + .as("features"); + predictTable = + tEnv.fromDataStream( + env.addSource(predictSource, DenseVectorTypeInfo.INSTANCE), schema) + .as("features"); + } + + @After + public void after() throws Exception { + miniCluster.close(); + } + + /** 
+ * Performs transform() on the provided model with predictTable, and adds sinks for + * OnlineKMeansModel's transform output and model data. + */ + private void configTransformAndSink(OnlineKMeansModel onlineModel) { + Table outputTable = onlineModel.transform(predictTable)[0]; + DataStream<Row> output = tEnv.toDataStream(outputTable); + output.addSink(outputSink); + + Table modelDataTable = onlineModel.getModelData()[0]; + DataStream<KMeansModelData> modelDataStream = + KMeansModelData.getModelDataStream(modelDataTable); + modelDataStream.addSink(modelDataSink); + } + + /** Blocks the thread until Model has set up init model data. */ + private void waitInitModelDataSetup() { + while (reporter.findMetrics("modelDataVersion").size() < defaultParallelism) { + Thread.yield(); + } + waitModelDataUpdate(); + } + + /** Blocks the thread until the Model has received the next model-data-update event. */ + @SuppressWarnings("unchecked") + private void waitModelDataUpdate() { + do { + String tmpModelDataVersion = + String.valueOf( + reporter.findMetrics("modelDataVersion").values().stream() + .map(x -> Integer.parseInt(((Gauge<String>) x).getValue())) + .min(Integer::compareTo) + .orElse(Integer.parseInt(currentModelDataVersion))); + if (tmpModelDataVersion.equals(currentModelDataVersion)) { + Thread.yield(); + } else { + currentModelDataVersion = tmpModelDataVersion; + break; + } + } while (true); + } + + /** + * Inserts default predict data to the predict queue, fetches the prediction results, and + * asserts that the grouping result is as expected. 
+ * + * @param expectedGroups A list containing sets of features, which is the expected group result + * @param featuresCol Name of the column in the table that contains the features + * @param predictionCol Name of the column in the table that contains the prediction result + */ + private void predictAndAssert( + List<Set<DenseVector>> expectedGroups, String featuresCol, String predictionCol) + throws Exception { + predictSource.offerAll(OnlineKMeansTest.predictData); + List<Row> rawResult = outputSink.poll(OnlineKMeansTest.predictData.length); + List<Set<DenseVector>> actualGroups = + groupFeaturesByPrediction(rawResult, featuresCol, predictionCol); + Assert.assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups)); + } + + @Test + public void testParam() { + OnlineKMeans onlineKMeans = new OnlineKMeans(); + Assert.assertEquals("features", onlineKMeans.getFeaturesCol()); + Assert.assertEquals("prediction", onlineKMeans.getPredictionCol()); + Assert.assertEquals(EuclideanDistanceMeasure.NAME, onlineKMeans.getDistanceMeasure()); + Assert.assertEquals("random", onlineKMeans.getInitMode()); + Assert.assertEquals(2, onlineKMeans.getK()); + Assert.assertEquals(1, onlineKMeans.getDims()); + Assert.assertEquals("count", onlineKMeans.getBatchStrategy()); + Assert.assertEquals(1, onlineKMeans.getBatchSize()); + Assert.assertEquals(0., onlineKMeans.getDecayFactor(), 1e-5); + Assert.assertEquals("random", onlineKMeans.getInitMode()); + Assert.assertEquals(OnlineKMeans.class.getName().hashCode(), onlineKMeans.getSeed()); + + onlineKMeans + .setK(9) + .setFeaturesCol("test_feature") + .setPredictionCol("test_prediction") + .setK(3) + .setDims(5) + .setBatchSize(5) + .setDecayFactor(0.25) + .setInitMode("direct") + .setSeed(100); + + Assert.assertEquals("test_feature", onlineKMeans.getFeaturesCol()); + Assert.assertEquals("test_prediction", onlineKMeans.getPredictionCol()); + Assert.assertEquals(3, onlineKMeans.getK()); + Assert.assertEquals(5, 
onlineKMeans.getDims()); + Assert.assertEquals("count", onlineKMeans.getBatchStrategy()); + Assert.assertEquals(5, onlineKMeans.getBatchSize()); + Assert.assertEquals(0.25, onlineKMeans.getDecayFactor(), 1e-5); + Assert.assertEquals("direct", onlineKMeans.getInitMode()); + Assert.assertEquals(100, onlineKMeans.getSeed()); + } + + @Test + public void testFitAndPredict() throws Exception { + OnlineKMeans onlineKMeans = + new OnlineKMeans() + .setInitMode("random") + .setDims(2) + .setInitWeights(new Double[] {0., 0.}) + .setBatchSize(6) + .setFeaturesCol("features") + .setPredictionCol("prediction"); + OnlineKMeansModel onlineModel = onlineKMeans.fit(trainTable); + configTransformAndSink(onlineModel); + + miniCluster.submitJob(env.getStreamGraph().getJobGraph()); + waitInitModelDataSetup(); + + trainSource.offerAll(trainData1); + waitModelDataUpdate(); + predictAndAssert( + expectedGroups1, onlineKMeans.getFeaturesCol(), onlineKMeans.getPredictionCol()); + + trainSource.offerAll(trainData2); + waitModelDataUpdate(); + predictAndAssert( + expectedGroups2, onlineKMeans.getFeaturesCol(), onlineKMeans.getPredictionCol()); + } + + @Test + public void testInitWithKMeans() throws Exception { + KMeans kMeans = new KMeans().setFeaturesCol("features").setPredictionCol("prediction"); + KMeansModel kMeansModel = kMeans.fit(offlineTrainTable); + + OnlineKMeans onlineKMeans = + new OnlineKMeans(kMeansModel.getModelData()) + .setFeaturesCol("features") + .setPredictionCol("prediction") + .setInitMode("direct") + .setDims(2) + .setInitWeights(new Double[] {0., 0.}) + .setBatchSize(6); + + OnlineKMeansModel onlineModel = onlineKMeans.fit(trainTable); + configTransformAndSink(onlineModel); + + miniCluster.submitJob(env.getStreamGraph().getJobGraph()); + waitInitModelDataSetup(); + predictAndAssert( + expectedGroups1, onlineKMeans.getFeaturesCol(), onlineKMeans.getPredictionCol()); + + trainSource.offerAll(trainData2); + waitModelDataUpdate(); + predictAndAssert( + expectedGroups2, 
onlineKMeans.getFeaturesCol(), onlineKMeans.getPredictionCol()); + } + + @Test + public void testDecayFactor() throws Exception { + KMeans kMeans = new KMeans().setFeaturesCol("features").setPredictionCol("prediction"); + KMeansModel kMeansModel = kMeans.fit(offlineTrainTable); + + OnlineKMeans onlineKMeans = + new OnlineKMeans(kMeansModel.getModelData()) + .setDims(2) + .setInitWeights(new Double[] {3., 3.}) + .setDecayFactor(0.5) + .setBatchSize(6) + .setFeaturesCol("features") + .setPredictionCol("prediction"); + OnlineKMeansModel onlineModel = onlineKMeans.fit(trainTable); + configTransformAndSink(onlineModel); + + miniCluster.submitJob(env.getStreamGraph().getJobGraph()); + modelDataSink.poll(); + + trainSource.offerAll(trainData2); + KMeansModelData actualModelData = modelDataSink.poll(); + + KMeansModelData expectedModelData = + new KMeansModelData( + new DenseVector[] { + Vectors.dense(10.1, 200.3 / 3), Vectors.dense(-10.2, -200.2 / 3) + }); + + Assert.assertEquals(expectedModelData.centroids.length, actualModelData.centroids.length); + Arrays.sort(actualModelData.centroids, (o1, o2) -> (int) (o2.values[0] - o1.values[0])); + for (int i = 0; i < expectedModelData.centroids.length; i++) { + Assert.assertArrayEquals( + expectedModelData.centroids[i].values, + actualModelData.centroids[i].values, + 1e-5); + } + } + + @Test + public void testSaveAndReload() throws Exception { + KMeans kMeans = new KMeans().setFeaturesCol("features").setPredictionCol("prediction"); + KMeansModel kMeansModel = kMeans.fit(offlineTrainTable); + + OnlineKMeans onlineKMeans = + new OnlineKMeans(kMeansModel.getModelData()) + .setFeaturesCol("features") + .setPredictionCol("prediction") + .setInitMode("direct") + .setDims(2) + .setInitWeights(new Double[] {0., 0.}) + .setBatchSize(6); + + String savePath = tempFolder.newFolder().getAbsolutePath(); + onlineKMeans.save(savePath); + miniCluster.executeJobBlocking(env.getStreamGraph().getJobGraph()); + OnlineKMeans loadedKMeans = 
OnlineKMeans.load(env, savePath); + + OnlineKMeansModel onlineModel = loadedKMeans.fit(trainTable); Review comment: nits: would it be a bit more readable to rename this variable as `model`, so that its name is more consistent with the `loadedModel` used below? ########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/InMemorySourceFunction.java ########## @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.util; + +import org.apache.flink.configuration.Configuration; +import org.apache.flink.streaming.api.functions.source.RichSourceFunction; +import org.apache.flink.streaming.api.functions.source.SourceFunction; + +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +/** A {@link SourceFunction} implementation that can directly receive records from tests. 
*/ +@SuppressWarnings({"unchecked", "rawtypes"}) +public class InMemorySourceFunction<T> extends RichSourceFunction<T> { + private static final Map<UUID, BlockingQueue> queueMap = new ConcurrentHashMap<>(); + private final UUID id; + private BlockingQueue<T> queue; + private boolean isRunning = true; + + public InMemorySourceFunction() { + id = UUID.randomUUID(); + queue = new LinkedBlockingQueue(); + queueMap.put(id, queue); + } + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + queue = queueMap.get(id); + } + + @Override + public void close() throws Exception { + super.close(); + queueMap.remove(id); + } + + @Override + public void run(SourceContext<T> context) { + while (isRunning) { + T value = queue.poll(); + if (value == null) { + Thread.yield(); + } else { + context.collect(value); + } + } + } + + @Override + public void cancel() { + isRunning = false; + } + + @SafeVarargs + public final void offerAll(T... values) throws InterruptedException { + for (T value : values) { + offer(value); + } + } + + public void offer(T value) throws InterruptedException { + offer(value, 1, TimeUnit.MINUTES); Review comment: Given that we don't limit the capacity of this queue, would it be simpler to call `queue.add(value)` here? If we agree to do this, we will also need to rename/remove those `offer*` methods as appropriate. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
