lindong28 commented on a change in pull request #70: URL: https://github.com/apache/flink-ml/pull/70#discussion_r835764832
########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/InMemorySinkFunction.java ########## @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.util; + +import org.apache.flink.configuration.Configuration; +import org.apache.flink.streaming.api.functions.sink.RichSinkFunction; +import org.apache.flink.streaming.api.functions.sink.SinkFunction; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +/** A {@link SinkFunction} implementation that makes all collected records available for tests. 
*/ +@SuppressWarnings({"unchecked", "rawtypes"}) +public class InMemorySinkFunction<T> extends RichSinkFunction<T> { + private static final Map<UUID, BlockingQueue> queueMap = new ConcurrentHashMap<>(); + private final UUID id; + private BlockingQueue<T> queue; + + public InMemorySinkFunction() { + id = UUID.randomUUID(); + queue = new LinkedBlockingQueue(); + queueMap.put(id, queue); + } + + @Override + public void open(Configuration parameters) throws Exception { + super.open(parameters); + queue = queueMap.get(id); + } + + @Override + public void close() throws Exception { + super.close(); + queueMap.remove(id); + } + + @Override + public void invoke(T value, Context context) { + if (!queue.offer(value)) { + throw new RuntimeException( + "Failed to offer " + value + " to blocking queue " + id + "."); + } + } + + public List<T> poll(int num) throws InterruptedException { + List<T> result = new ArrayList<>(); + for (int i = 0; i < num; i++) { + result.add(poll()); + } + return result; + } + + public T poll() throws InterruptedException { + return poll(1, TimeUnit.MINUTES); + } + + public T poll(long timeout, TimeUnit unit) throws InterruptedException { Review comment: nits: would it be simpler to remove this method and move its content to `T poll()`? ########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/InMemorySourceFunction.java ########## @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.util; + +import org.apache.flink.configuration.Configuration; +import org.apache.flink.streaming.api.functions.source.RichSourceFunction; +import org.apache.flink.streaming.api.functions.source.SourceFunction; +import org.apache.flink.util.Preconditions; + +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.LinkedBlockingQueue; + +/** A {@link SourceFunction} implementation that can directly receive records from tests. 
*/ +@SuppressWarnings({"unchecked", "rawtypes"}) +public class InMemorySourceFunction<T> extends RichSourceFunction<T> { +    private static final Map<UUID, BlockingQueue> queueMap = new ConcurrentHashMap<>(); +    private final UUID id; +    private BlockingQueue<Optional<T>> queue; +    private volatile boolean isRunning = true; + +    public InMemorySourceFunction() { +        id = UUID.randomUUID(); +        queue = new LinkedBlockingQueue(); +        queueMap.put(id, queue); +    } + +    @Override +    public void open(Configuration parameters) throws Exception { +        super.open(parameters); +        queue = queueMap.get(id); +    } + +    @Override +    public void close() throws Exception { +        super.close(); +        queueMap.remove(id); +    } + +    @Override +    public void run(SourceContext<T> context) throws InterruptedException { +        while (isRunning) { +            Optional<T> maybeValue = queue.take(); +            if (!maybeValue.isPresent()) { +                continue; Review comment: Given that this can only happen after `cancel()` is invoked, should it be `break`? ########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java ########## @@ -177,11 +177,20 @@ public void testFewerDistinctPointsThanCluster() {            KMeans kmeans = new KMeans().setK(2);            KMeansModel model = kmeans.fit(input);            Table output = model.transform(input)[0];   -        List<Set<DenseVector>> expectedGroups =   -                Collections.singletonList(Collections.singleton(Vectors.dense(0.0, 0.1)));   -        List<Set<DenseVector>> actualGroups =   -                executeAndCollect(output, kmeans.getFeaturesCol(), kmeans.getPredictionCol());   -        assertTrue(CollectionUtils.isEqualCollection(expectedGroups, actualGroups));   +   +        try { Review comment: Hmm... it appears that the behavior of the KMeans algorithm has changed. 
Spark's `org.apache.spark.mllib.clustering` has a test named `fewer distinct points than clusters`, which does not throw an exception when there are fewer unique points than `K`. Could we keep the same behavior here?
########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java ########## @@ -0,0 +1,407 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.clustering.kmeans; + +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo; +import org.apache.flink.iteration.DataStreamList; +import org.apache.flink.iteration.IterationBody; +import org.apache.flink.iteration.IterationBodyResult; +import org.apache.flink.iteration.Iterations; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.distance.DistanceMeasure; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.param.Param; 
+import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.api.windowing.windows.GlobalWindow; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.collections.IteratorUtils; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Supplier; + +/** + * OnlineKMeans extends the function of {@link KMeans}, supporting to train a K-Means model + * continuously according to an unbounded stream of train data. + * + * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate + * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired, + * OnlineKMeans computes the new centroids from the weighted average between the original and the + * estimated centroids. The weight of the estimated centroids is the number of points assigned to + * them. The weight of the original centroids is also the number of points, but additionally + * multiplying with the decay factor. 
+ * + * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay + * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are + * determined entirely by recent data. Lower values correspond to more forgetting. + */ +public class OnlineKMeans + implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private Table initModelDataTable; + + public OnlineKMeans() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public OnlineKMeansModel fit(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + + DataStream<DenseVector> points = + tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol())); + + DataStream<KMeansModelData> initModelData = + KMeansModelData.getModelDataStream(initModelDataTable); + initModelData.getTransformation().setParallelism(1); + + IterationBody body = + new OnlineKMeansIterationBody( + DistanceMeasure.getInstance(getDistanceMeasure()), + getK(), + getDecayFactor(), + getGlobalBatchSize()); + + DataStream<KMeansModelData> onlineModelData = + Iterations.iterateUnboundedStreams( + DataStreamList.of(initModelData), DataStreamList.of(points), body) + .get(0); + + Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData); + OnlineKMeansModel model = new OnlineKMeansModel().setModelData(onlineModelDataTable); + ReadWriteUtils.updateExistingParams(model, paramMap); + return model; + } + + /** Saves the metadata AND bounded model data table (if exists) to the given path. 
*/ +    @Override +    public void save(String path) throws IOException { +        if (initModelDataTable != null) { +            ReadWriteUtils.saveModelData( +                    KMeansModelData.getModelDataStream(initModelDataTable), +                    path, +                    new KMeansModelData.ModelDataEncoder()); +        } + +        ReadWriteUtils.saveMetadata(this, path); +    } + +    public static OnlineKMeans load(StreamExecutionEnvironment env, String path) +            throws IOException { +        OnlineKMeans onlineKMeans = ReadWriteUtils.loadStageParam(path); + +        String initModelDataPath = ReadWriteUtils.getDataPath(path); +        if (Files.exists(Paths.get(initModelDataPath))) { +            StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); + +            DataStream<KMeansModelData> initModelDataStream = +                    ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder()); + +            onlineKMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream); +        } + +        return onlineKMeans; +    } + +    @Override +    public Map<Param<?>, Object> getParamMap() { +        return paramMap; +    } + +    private static class OnlineKMeansIterationBody implements IterationBody { +        private final DistanceMeasure distanceMeasure; +        private final int k; +        private final double decayFactor; +        private final int batchSize; + +        public OnlineKMeansIterationBody( +                DistanceMeasure distanceMeasure, int k, double decayFactor, int batchSize) { +            this.distanceMeasure = distanceMeasure; +            this.k = k; +            this.decayFactor = decayFactor; +            this.batchSize = batchSize; +        } + +        @Override +        public IterationBodyResult process( +                DataStreamList variableStreams, DataStreamList dataStreams) { +            DataStream<KMeansModelData> modelData = variableStreams.get(0); +            DataStream<DenseVector> points = dataStreams.get(0); + +            int parallelism = points.getParallelism(); Review comment: Would it be useful to explicitly check that `parallelism <= batchSize` and throw an exception otherwise? If the global batch size is smaller than the parallelism, some subtasks would receive no points in a batch.
########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansUtils.java ########## @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.clustering.kmeans; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; + +import java.util.Arrays; +import java.util.Random; + +/** Utility methods for KMeans algorithm. */ +public class KMeansUtils { + /** + * Generates a Table containing a {@link KMeansModelData} instance with randomly generated + * centroids. + * + * @param env The environment where to create the table. + * @param k The number of generated centroids. + * @param dim The size of generated centroids. + * @param weight The weight of the centroids. + * @param seed Random seed. + */ + public static Table generateRandomModelData( Review comment: Would it be simpler to move this method to `KMeansModelData`? 
########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansParams.java ########## @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.clustering.kmeans; + +import org.apache.flink.ml.common.param.HasBatchStrategy; +import org.apache.flink.ml.common.param.HasDecayFactor; +import org.apache.flink.ml.common.param.HasGlobalBatchSize; +import org.apache.flink.ml.common.param.HasSeed; +import org.apache.flink.ml.param.DoubleParam; +import org.apache.flink.ml.param.IntParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; + +/** + * Params of {@link OnlineKMeans}. + * + * @param <T> The class type of this instance. + */ +public interface OnlineKMeansParams<T> + extends HasBatchStrategy<T>, + HasGlobalBatchSize<T>, + HasDecayFactor<T>, + HasSeed<T>, + KMeansModelParams<T> { + Param<Integer> DIM = Review comment: We can remove `DIM` and `INIT_WEIGHT`. ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java ########## @@ -0,0 +1,407 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.clustering.kmeans; + +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo; +import org.apache.flink.iteration.DataStreamList; +import org.apache.flink.iteration.IterationBody; +import org.apache.flink.iteration.IterationBodyResult; +import org.apache.flink.iteration.Iterations; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.distance.DistanceMeasure; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import 
org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.api.windowing.windows.GlobalWindow; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.collections.IteratorUtils; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Supplier; + +/** + * OnlineKMeans extends the function of {@link KMeans}, supporting to train a K-Means model + * continuously according to an unbounded stream of train data. + * + * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate + * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired, + * OnlineKMeans computes the new centroids from the weighted average between the original and the + * estimated centroids. The weight of the estimated centroids is the number of points assigned to + * them. The weight of the original centroids is also the number of points, but additionally + * multiplying with the decay factor. + * + * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay + * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are + * determined entirely by recent data. Lower values correspond to more forgetting. 
+ */ +public class OnlineKMeans + implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private Table initModelDataTable; + + public OnlineKMeans() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public OnlineKMeansModel fit(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + + DataStream<DenseVector> points = + tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol())); + + DataStream<KMeansModelData> initModelData = + KMeansModelData.getModelDataStream(initModelDataTable); + initModelData.getTransformation().setParallelism(1); + + IterationBody body = + new OnlineKMeansIterationBody( + DistanceMeasure.getInstance(getDistanceMeasure()), + getK(), + getDecayFactor(), + getGlobalBatchSize()); + + DataStream<KMeansModelData> onlineModelData = + Iterations.iterateUnboundedStreams( + DataStreamList.of(initModelData), DataStreamList.of(points), body) + .get(0); + + Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData); + OnlineKMeansModel model = new OnlineKMeansModel().setModelData(onlineModelDataTable); + ReadWriteUtils.updateExistingParams(model, paramMap); + return model; + } + + /** Saves the metadata AND bounded model data table (if exists) to the given path. 
*/ + @Override + public void save(String path) throws IOException { + if (initModelDataTable != null) { + ReadWriteUtils.saveModelData( + KMeansModelData.getModelDataStream(initModelDataTable), + path, + new KMeansModelData.ModelDataEncoder()); + } + + ReadWriteUtils.saveMetadata(this, path); + } + + public static OnlineKMeans load(StreamExecutionEnvironment env, String path) + throws IOException { + OnlineKMeans onlineKMeans = ReadWriteUtils.loadStageParam(path); + + String initModelDataPath = ReadWriteUtils.getDataPath(path); + if (Files.exists(Paths.get(initModelDataPath))) { + StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); + + DataStream<KMeansModelData> initModelDataStream = + ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder()); + + onlineKMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream); + } + + return onlineKMeans; + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + private static class OnlineKMeansIterationBody implements IterationBody { + private final DistanceMeasure distanceMeasure; + private final int k; + private final double decayFactor; + private final int batchSize; + + public OnlineKMeansIterationBody( + DistanceMeasure distanceMeasure, int k, double decayFactor, int batchSize) { + this.distanceMeasure = distanceMeasure; + this.k = k; + this.decayFactor = decayFactor; + this.batchSize = batchSize; + } + + @Override + public IterationBodyResult process( + DataStreamList variableStreams, DataStreamList dataStreams) { + DataStream<KMeansModelData> modelData = variableStreams.get(0); + DataStream<DenseVector> points = dataStreams.get(0); + + int parallelism = points.getParallelism(); + + DataStream<KMeansModelData> newModelData = + points.countWindowAll(batchSize) + .apply(new GlobalBatchCreator()) + .flatMap(new GlobalBatchSplitter(parallelism)) + .rebalance() + .connect(modelData.broadcast()) + .transform( + "ModelDataLocalUpdater", + 
TypeInformation.of(KMeansModelData.class), + new ModelDataLocalUpdater(distanceMeasure, k, decayFactor)) + .setParallelism(parallelism) + .countWindowAll(parallelism) + .reduce(new ModelDataGlobalReducer()); + + return new IterationBodyResult( + DataStreamList.of(newModelData), DataStreamList.of(modelData)); + } + } + + /** + * Operator that collects a KMeansModelData from each upstream subtask, and outputs the weight + * average of collected model data. + */ + private static class ModelDataGlobalReducer implements ReduceFunction<KMeansModelData> { + @Override + public KMeansModelData reduce(KMeansModelData modelData, KMeansModelData newModelData) { + DenseVector weights = modelData.weights; + DenseVector[] centroids = modelData.centroids; + DenseVector newWeights = newModelData.weights; + DenseVector[] newCentroids = newModelData.centroids; + + int k = newCentroids.length; + int dim = newCentroids[0].size(); + + for (int i = 0; i < k; i++) { + for (int j = 0; j < dim; j++) { + centroids[i].values[j] = + (centroids[i].values[j] * weights.values[i] + + newCentroids[i].values[j] * newWeights.values[i]) + / Math.max(weights.values[i] + newWeights.values[i], 1e-16); + } + weights.values[i] += newWeights.values[i]; + } + + return new KMeansModelData(centroids, weights); + } + } + + /** + * An operator that updates KMeans model data locally. It mainly does the following operations. + * + * <ul> + * <li>Finds the closest centroid id (cluster) of the input points + * <li>Computes the new centroids from the average of input points that belongs to the same + * cluster + * <li>Computes the weighted average of current and new centroids. The weight of a new + * centroid is the number of input points that belong to this cluster. The weight of a + * current centroid is its original weight scaled by $ decayFactor / parallelism $. + * <li>Generates new model data from the weighted average of centroids, and the sum of + * weights. 
+ * </ul> + */ + private static class ModelDataLocalUpdater extends AbstractStreamOperator<KMeansModelData> + implements TwoInputStreamOperator<DenseVector[], KMeansModelData, KMeansModelData> { + private final DistanceMeasure distanceMeasure; + private final int k; + private final double decayFactor; + private ListState<DenseVector[]> localBatchDataState; + private ListState<KMeansModelData> modelDataState; + + private ModelDataLocalUpdater(DistanceMeasure distanceMeasure, int k, double decayFactor) { + this.distanceMeasure = distanceMeasure; + this.k = k; + this.decayFactor = decayFactor; + } + + @Override + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + + TypeInformation<DenseVector[]> type = + ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE); + localBatchDataState = + context.getOperatorStateStore() + .getListState(new ListStateDescriptor<>("localBatch", type)); + + modelDataState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>("modelData", KMeansModelData.class)); + } + + @Override + public void processElement1(StreamRecord<DenseVector[]> pointsRecord) throws Exception { + localBatchDataState.add(pointsRecord.getValue()); + alignAndComputeModelData(); + } + + @Override + public void processElement2(StreamRecord<KMeansModelData> modelDataRecord) + throws Exception { + Preconditions.checkArgument(modelDataRecord.getValue().centroids.length == k); + modelDataState.add(modelDataRecord.getValue()); + alignAndComputeModelData(); + } + + private void alignAndComputeModelData() throws Exception { + if (!modelDataState.get().iterator().hasNext() + || !localBatchDataState.get().iterator().hasNext()) { + return; + } + + KMeansModelData modelData = + OperatorStateUtils.getUniqueElement(modelDataState, "modelData") + .orElseThrow((Supplier<Exception>) NullPointerException::new); Review comment: Given that `modelDataState.get().iterator().hasNext() == true`, 
the `orElseThrow()` should not be triggered, right? Would it be simpler to just call `get()` here? ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java ########## @@ -0,0 +1,407 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.clustering.kmeans; + +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo; +import org.apache.flink.iteration.DataStreamList; +import org.apache.flink.iteration.IterationBody; +import org.apache.flink.iteration.IterationBodyResult; +import org.apache.flink.iteration.Iterations; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.distance.DistanceMeasure; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.api.windowing.windows.GlobalWindow; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import 
org.apache.flink.util.Preconditions; + +import org.apache.commons.collections.IteratorUtils; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Supplier; + +/** + * OnlineKMeans extends the function of {@link KMeans}, supporting to train a K-Means model + * continuously according to an unbounded stream of train data. + * + * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate + * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired, + * OnlineKMeans computes the new centroids from the weighted average between the original and the + * estimated centroids. The weight of the estimated centroids is the number of points assigned to + * them. The weight of the original centroids is also the number of points, but additionally + * multiplying with the decay factor. + * + * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay + * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are + * determined entirely by recent data. Lower values correspond to more forgetting. + */ +public class OnlineKMeans + implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private Table initModelDataTable; + + public OnlineKMeans() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public OnlineKMeansModel fit(Table... 
inputs) { + Preconditions.checkArgument(inputs.length == 1); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + + DataStream<DenseVector> points = + tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol())); + + DataStream<KMeansModelData> initModelData = + KMeansModelData.getModelDataStream(initModelDataTable); + initModelData.getTransformation().setParallelism(1); + + IterationBody body = + new OnlineKMeansIterationBody( + DistanceMeasure.getInstance(getDistanceMeasure()), + getK(), + getDecayFactor(), + getGlobalBatchSize()); + + DataStream<KMeansModelData> onlineModelData = + Iterations.iterateUnboundedStreams( + DataStreamList.of(initModelData), DataStreamList.of(points), body) + .get(0); + + Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData); + OnlineKMeansModel model = new OnlineKMeansModel().setModelData(onlineModelDataTable); + ReadWriteUtils.updateExistingParams(model, paramMap); + return model; + } + + /** Saves the metadata AND bounded model data table (if exists) to the given path. */ + @Override + public void save(String path) throws IOException { + if (initModelDataTable != null) { Review comment: Currently `fit()` requires users to always explicitly call `setInitialModelData()` before `fit()` is called. Would it be simpler to also require users to call `setInitialModelData()` before calling `save()`? If not, it would not be clear whether users can call `fit()` right after an `OnlineKMeansModel` is loaded. Same for `load()`. ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java ########## @@ -0,0 +1,407 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.clustering.kmeans; + +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo; +import org.apache.flink.iteration.DataStreamList; +import org.apache.flink.iteration.IterationBody; +import org.apache.flink.iteration.IterationBodyResult; +import org.apache.flink.iteration.Iterations; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.distance.DistanceMeasure; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import 
org.apache.flink.streaming.api.functions.windowing.AllWindowFunction; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.TwoInputStreamOperator; +import org.apache.flink.streaming.api.windowing.windows.GlobalWindow; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.collections.IteratorUtils; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Supplier; + +/** + * OnlineKMeans extends the function of {@link KMeans}, supporting to train a K-Means model + * continuously according to an unbounded stream of train data. + * + * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate + * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired, + * OnlineKMeans computes the new centroids from the weighted average between the original and the + * estimated centroids. The weight of the estimated centroids is the number of points assigned to + * them. The weight of the original centroids is also the number of points, but additionally + * multiplying with the decay factor. + * + * <p>The decay factor scales the contribution of the clusters as estimated thus far. If the decay + * factor is 1, all batches are weighted equally. If the decay factor is 0, new centroids are + * determined entirely by recent data. Lower values correspond to more forgetting. 
+ */ +public class OnlineKMeans + implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private Table initModelDataTable; + + public OnlineKMeans() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public OnlineKMeansModel fit(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + + DataStream<DenseVector> points = + tEnv.toDataStream(inputs[0]).map(new FeaturesExtractor(getFeaturesCol())); + + DataStream<KMeansModelData> initModelData = + KMeansModelData.getModelDataStream(initModelDataTable); + initModelData.getTransformation().setParallelism(1); + + IterationBody body = + new OnlineKMeansIterationBody( + DistanceMeasure.getInstance(getDistanceMeasure()), + getK(), + getDecayFactor(), + getGlobalBatchSize()); + + DataStream<KMeansModelData> onlineModelData = + Iterations.iterateUnboundedStreams( + DataStreamList.of(initModelData), DataStreamList.of(points), body) + .get(0); + + Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData); + OnlineKMeansModel model = new OnlineKMeansModel().setModelData(onlineModelDataTable); + ReadWriteUtils.updateExistingParams(model, paramMap); + return model; + } + + /** Saves the metadata AND bounded model data table (if exists) to the given path. 
*/ + @Override + public void save(String path) throws IOException { + if (initModelDataTable != null) { + ReadWriteUtils.saveModelData( + KMeansModelData.getModelDataStream(initModelDataTable), + path, + new KMeansModelData.ModelDataEncoder()); + } + + ReadWriteUtils.saveMetadata(this, path); + } + + public static OnlineKMeans load(StreamExecutionEnvironment env, String path) + throws IOException { + OnlineKMeans onlineKMeans = ReadWriteUtils.loadStageParam(path); + + String initModelDataPath = ReadWriteUtils.getDataPath(path); + if (Files.exists(Paths.get(initModelDataPath))) { + StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); + + DataStream<KMeansModelData> initModelDataStream = + ReadWriteUtils.loadModelData(env, path, new KMeansModelData.ModelDataDecoder()); + + onlineKMeans.initModelDataTable = tEnv.fromDataStream(initModelDataStream); + } + + return onlineKMeans; + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + private static class OnlineKMeansIterationBody implements IterationBody { + private final DistanceMeasure distanceMeasure; + private final int k; + private final double decayFactor; + private final int batchSize; + + public OnlineKMeansIterationBody( + DistanceMeasure distanceMeasure, int k, double decayFactor, int batchSize) { + this.distanceMeasure = distanceMeasure; + this.k = k; + this.decayFactor = decayFactor; + this.batchSize = batchSize; + } + + @Override + public IterationBodyResult process( + DataStreamList variableStreams, DataStreamList dataStreams) { + DataStream<KMeansModelData> modelData = variableStreams.get(0); + DataStream<DenseVector> points = dataStreams.get(0); + + int parallelism = points.getParallelism(); + + DataStream<KMeansModelData> newModelData = + points.countWindowAll(batchSize) + .apply(new GlobalBatchCreator()) + .flatMap(new GlobalBatchSplitter(parallelism)) + .rebalance() + .connect(modelData.broadcast()) + .transform( + "ModelDataLocalUpdater", + 
TypeInformation.of(KMeansModelData.class), + new ModelDataLocalUpdater(distanceMeasure, k, decayFactor)) + .setParallelism(parallelism) + .countWindowAll(parallelism) + .reduce(new ModelDataGlobalReducer()); + + return new IterationBodyResult( + DataStreamList.of(newModelData), DataStreamList.of(modelData)); + } + } + + /** + * Operator that collects a KMeansModelData from each upstream subtask, and outputs the weight + * average of collected model data. + */ + private static class ModelDataGlobalReducer implements ReduceFunction<KMeansModelData> { + @Override + public KMeansModelData reduce(KMeansModelData modelData, KMeansModelData newModelData) { + DenseVector weights = modelData.weights; + DenseVector[] centroids = modelData.centroids; + DenseVector newWeights = newModelData.weights; + DenseVector[] newCentroids = newModelData.centroids; + + int k = newCentroids.length; + int dim = newCentroids[0].size(); + + for (int i = 0; i < k; i++) { + for (int j = 0; j < dim; j++) { + centroids[i].values[j] = + (centroids[i].values[j] * weights.values[i] + + newCentroids[i].values[j] * newWeights.values[i]) + / Math.max(weights.values[i] + newWeights.values[i], 1e-16); + } + weights.values[i] += newWeights.values[i]; + } + + return new KMeansModelData(centroids, weights); + } + } + + /** + * An operator that updates KMeans model data locally. It mainly does the following operations. + * + * <ul> + * <li>Finds the closest centroid id (cluster) of the input points Review comment: nits: `.` is missing at the end of the sentence. Same for the following sentence. ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansModel.java ########## @@ -0,0 +1,199 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.clustering.kmeans; + +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.metrics.Gauge; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.common.distance.DistanceMeasure; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.co.CoProcessFunction; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * OnlineKMeansModel can be regarded as an advanced {@link KMeansModel} operator which can update + * model data in a 
streaming format, using the model data provided by {@link OnlineKMeans}. + */ +public class OnlineKMeansModel + implements Model<OnlineKMeansModel>, KMeansModelParams<OnlineKMeansModel> { + public static final String MODEL_DATA_VERSION_GAUGE_KEY = "modelDataVersion"; + + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private Table modelDataTable; + + public OnlineKMeansModel() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public OnlineKMeansModel setModelData(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + modelDataTable = inputs[0]; + return this; + } + + @Override + public Table[] getModelData() { + return new Table[] {modelDataTable}; + } + + @Override + public Table[] transform(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.INT), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol())); + + DataStream<Row> predictionResult = + KMeansModelData.getModelDataStream(modelDataTable) + .broadcast() + .connect(tEnv.toDataStream(inputs[0])) + .process( + new PredictLabelFunction( + getFeaturesCol(), + DistanceMeasure.getInstance(getDistanceMeasure()), + getK()), + outputTypeInfo); + + return new Table[] {tEnv.fromDataStream(predictionResult)}; + } + + /** A utility function used for prediction. 
*/ + private static class PredictLabelFunction extends CoProcessFunction<KMeansModelData, Row, Row> { + private final String featuresCol; + + private final DistanceMeasure distanceMeasure; + + private final int k; + + private DenseVector[] centroids; + + // TODO: replace this with a complete solution of reading first model data from unbounded + // model data stream before processing the first predict data. + private final List<Row> bufferedPoints = new ArrayList<>(); + + /** + * Basic implementation of the model data version with the following rules. + * + * <ul> + * <li>Negative value is regarded as illegal value. + * <li>Zero value means the version has not been initialized yet. + * <li>Positive value represents valid version. + * <li>A larger value represents a newer version. Review comment: Given that this is `version`, it seems unnecessary to mention `A larger value represents a newer version`. Would it be simpler to remove this sentence? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
