lindong28 commented on a change in pull request #70:
URL: https://github.com/apache/flink-ml/pull/70#discussion_r832765505



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,569 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Random;
+
+/**
+ * OnlineKMeans extends the function of {@link KMeans}, supporting to train a K-Means model
+ * continuously according to an unbounded stream of train data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is also the number of points, but additionally
+ * multiplying with the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If decay
+ * factor is 1, all batches are weighted equally. If decay factor is 0, new centroids are determined
+ * entirely by recent data. Lower values correspond to more forgetting.
+ *
+ * <p>NOTE: This class's current naive implementation performs the training process in a
+ * single-threaded way. Correctness is not affected but there are performance issues.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public OnlineKMeans(Table... initModelDataTables) {
+        Preconditions.checkArgument(initModelDataTables.length == 1);
+        this.initModelDataTable = initModelDataTables[0];
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+        setInitMode("direct");
+    }
+
+    @Override
+    public OnlineKMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Preconditions.checkArgument(HasBatchStrategy.COUNT_STRATEGY.equals(getBatchStrategy()));

Review comment:
       Given that the validator of BATCH_STRATEGY should already guarantee that 
`getBatchStrategy() == count`, would it be simpler to remove this check? This 
would also be more consistent with how NaiveBayes handles modelType.
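
To illustrate the point about the validator, here is a minimal, self-contained sketch. It does not use the flink-ml classes: `ValidatedParamSketch`, its nested `StringParam`, and the `set`/`get` methods are hypothetical stand-ins, and the sketch assumes that validation happens when a value is set. Under that assumption, any value read back later is already known to be valid, so a second check inside `fit` adds nothing.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.function.Predicate;

public class ValidatedParamSketch {
    /** Hypothetical stand-in for a string parameter that carries its own validator. */
    public static final class StringParam {
        public final String name;
        public final Predicate<String> validator;

        public StringParam(String name, Predicate<String> validator) {
            this.name = name;
            this.validator = validator;
        }
    }

    /** Only "count" is accepted, mirroring the situation described in the comment. */
    public static final StringParam BATCH_STRATEGY =
            new StringParam("batchStrategy", v -> Arrays.asList("count").contains(v));

    private final Map<String, String> values = new HashMap<>();

    /** Assumption: the validator runs on every set, so invalid values never get stored. */
    public void set(StringParam param, String value) {
        if (!param.validator.test(value)) {
            throw new IllegalArgumentException(
                    "Invalid value '" + value + "' for parameter " + param.name);
        }
        values.put(param.name, value);
    }

    public String get(StringParam param) {
        return values.get(param.name);
    }

    public static void main(String[] args) {
        ValidatedParamSketch params = new ValidatedParamSketch();
        params.set(BATCH_STRATEGY, "count");
        // No further check is needed here: the stored value has already passed validation.
        System.out.println(params.get(BATCH_STRATEGY));
    }
}
```

If the real validator is only applied in some code paths, the explicit check would still have value, so this is worth confirming before removing it.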

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeansParams.java
##########
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.common.param.HasDecayFactor;
+import org.apache.flink.ml.common.param.HasSeed;
+import org.apache.flink.ml.param.DoubleArrayParam;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+
+/**
+ * Params of {@link OnlineKMeans}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface OnlineKMeansParams<T>
+        extends HasBatchStrategy<T>, HasDecayFactor<T>, HasSeed<T>, KMeansModelParams<T> {
+    Param<String> INIT_MODE =
+            new StringParam(
+                    "initMode",
+                    "How to initialize the model data of the online KMeans algorithm. Supported options: 'random', 'direct'.",
+                    "random",
+                    ParamValidators.inArray("random", "direct"));
+
+    Param<Integer> DIMS =

Review comment:
       This parameter indicates the dimension of the input vector. Would it be 
more intuitive to name it `dim` instead of `dims`?
   
   If we agree to do so, we probably need to rename other related variables as 
well.

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/OnlineKMeans.java
##########
@@ -0,0 +1,569 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.kmeans;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.bridge.java.internal.StreamTableEnvironmentImpl;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Random;
+
+/**
+ * OnlineKMeans extends the function of {@link KMeans}, supporting to train a K-Means model
+ * continuously according to an unbounded stream of train data.
+ *
+ * <p>OnlineKMeans makes updates with the "mini-batch" KMeans rule, generalized to incorporate
+ * forgetfulness (i.e. decay). After the centroids estimated on the current batch are acquired,
+ * OnlineKMeans computes the new centroids from the weighted average between the original and the
+ * estimated centroids. The weight of the estimated centroids is the number of points assigned to
+ * them. The weight of the original centroids is also the number of points, but additionally
+ * multiplying with the decay factor.
+ *
+ * <p>The decay factor scales the contribution of the clusters as estimated thus far. If decay
+ * factor is 1, all batches are weighted equally. If decay factor is 0, new centroids are determined
+ * entirely by recent data. Lower values correspond to more forgetting.
+ *
+ * <p>NOTE: This class's current naive implementation performs the training process in a
+ * single-threaded way. Correctness is not affected but there are performance issues.
+ */
+public class OnlineKMeans
+        implements Estimator<OnlineKMeans, OnlineKMeansModel>, OnlineKMeansParams<OnlineKMeans> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public OnlineKMeans() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    public OnlineKMeans(Table... initModelDataTables) {

Review comment:
       Given that we know for sure `initModelDataTables.length == 1`, would it 
be simpler to just use `Table initModelDataTable` as the input parameter?
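
A rough sketch of the difference, using hypothetical stand-in classes rather than the real ones: `Table` below is a placeholder for `org.apache.flink.table.api.Table`, and `VarargsEstimator` / `SingleArgEstimator` are made-up names for the two constructor shapes. With varargs, a call that passes zero or two tables compiles and only fails at runtime; with a single parameter, the compiler rejects such calls.

```java
import java.util.Objects;

public class SingleTableConstructorSketch {
    /** Hypothetical placeholder for org.apache.flink.table.api.Table. */
    public static final class Table {
        public final String name;

        public Table(String name) {
            this.name = name;
        }
    }

    /** Varargs shape: the arity is only checked at runtime. */
    public static final class VarargsEstimator {
        public final Table initModelDataTable;

        public VarargsEstimator(Table... initModelDataTables) {
            if (initModelDataTables.length != 1) {
                throw new IllegalArgumentException("Expected exactly one table");
            }
            this.initModelDataTable = initModelDataTables[0];
        }
    }

    /** Single-parameter shape: the arity is enforced by the compiler. */
    public static final class SingleArgEstimator {
        public final Table initModelDataTable;

        public SingleArgEstimator(Table initModelDataTable) {
            this.initModelDataTable = Objects.requireNonNull(initModelDataTable);
        }
    }

    public static void main(String[] args) {
        Table table = new Table("initModelData");

        // Both shapes accept exactly one table.
        System.out.println(new VarargsEstimator(table).initModelDataTable.name);
        System.out.println(new SingleArgEstimator(table).initModelDataTable.name);

        // This compiles with varargs but fails when executed.
        try {
            new VarargsEstimator();
        } catch (IllegalArgumentException e) {
            System.out.println("Rejected at runtime: " + e.getMessage());
        }
        // new SingleArgEstimator(); // would not compile
    }
}
```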
   




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

