zhipeng93 commented on code in PR #148:
URL: https://github.com/apache/flink-ml/pull/148#discussion_r958215741


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/agglomerativeclustering/AgglomerativeClustering.java:
##########
@@ -0,0 +1,415 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.agglomerativeclustering;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.VectorWithNorm;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An AlgoOperator that performs a hierarchical clustering using a bottom up 
approach. Each
+ * observation starts in its own cluster and the clusters are merged together 
one by one. Users can
+ * choose different strategies to merge two clusters by setting {@link
+ * AgglomerativeClusteringParams#LINKAGE} and different distance measure by 
setting {@link
+ * AgglomerativeClusteringParams#DISTANCE_MEASURE}.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Hierarchical_clustering.
+ */
+public class AgglomerativeClustering
+        implements AlgoOperator<AgglomerativeClustering>,
+                AgglomerativeClusteringParams<AgglomerativeClustering> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public AgglomerativeClustering() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Integer k = getK();
+        Double distanceThreshold = getDistanceThreshold();
+        Preconditions.checkArgument(
+                (k == null && distanceThreshold != null)
+                        || (k != null && distanceThreshold == null),
+                "One of param k and distanceThreshold should be null.");
+
+        if (getLinkage().equals(LINKAGE_WARD)) {
+            String distanceMeasure = getDistanceMeasure();
+            Preconditions.checkArgument(
+                    distanceMeasure.equals(EuclideanDistanceMeasure.NAME),
+                    distanceMeasure
+                            + " was provided as distance measure while linkage 
was ward. Ward only works with euclidean.");
+        }
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        DataStream<Row> dataStream = tEnv.toDataStream(inputs[0]);
+
+        RowTypeInfo inputTypeInfo = 
TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), 
Types.INT),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), 
getPredictionCol()));
+
+        OutputTag<Tuple4<Integer, Integer, Double, Integer>> 
mergeInfoOutputTag =
+                new OutputTag<Tuple4<Integer, Integer, Double, 
Integer>>("MERGE_INFO") {};
+
+        SingleOutputStreamOperator<Row> output =
+                dataStream.transform(
+                        "doLocalAgglomerativeClustering",
+                        outputTypeInfo,
+                        new LocalAgglomerativeClusteringOperator(
+                                getFeaturesCol(),
+                                getLinkage(),
+                                getDistanceMeasure(),
+                                getK(),
+                                getDistanceThreshold(),
+                                getComputeFullTree(),
+                                mergeInfoOutputTag));
+        output.getTransformation().setParallelism(1);
+
+        Table outputTable = tEnv.fromDataStream(output);
+
+        DataStream<Tuple4<Integer, Integer, Double, Integer>> mergeInfo =
+                output.getSideOutput(mergeInfoOutputTag);
+        Table mergeInfoTable =
+                tEnv.fromDataStream(mergeInfo)
+                        .as("clusterId1", "clusterId2", "distance", 
"sizeOfMergedCluster");
+
+        return new Table[] {outputTable, mergeInfoTable};
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static AgglomerativeClustering load(StreamTableEnvironment env, 
String path)
+            throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    private static class LocalAgglomerativeClusteringOperator extends 
AbstractStreamOperator<Row>
+            implements OneInputStreamOperator<Row, Row>, BoundedOneInput {
+        private final String featuresCol;
+        private final String linkage;
+        private final DistanceMeasure distanceMeasure;
+        private final Integer k;
+        private final Double distanceThreshold;
+        private final boolean computeFullTree;
+        private final OutputTag<Tuple4<Integer, Integer, Double, Integer>> 
mergeInfoOutputTag;
+
+        /** Cache for the input data. */
+        private List<Row> inputList;
+        /** State for the input data. */
+        private ListState<Row> inputListState;
+        /** Cluster id of each data point in inputList. */
+        private int[] clusterIds;
+        /** Precomputes the norm of each vector for performance. */
+        private VectorWithNorm[] vectorWithNorms;
+        /** Next cluster Id to be assigned. */
+        private int nextClusterId = 0;
+
+        public LocalAgglomerativeClusteringOperator(
+                String featuresCol,
+                String linkage,
+                String distanceMeasureName,
+                Integer k,
+                Double distanceThreshold,
+                boolean computeFullTree,
+                OutputTag<Tuple4<Integer, Integer, Double, Integer>> 
mergeInfoOutputTag) {
+            this.featuresCol = featuresCol;
+            this.linkage = linkage;
+            this.k = k;
+            this.distanceThreshold = distanceThreshold;
+            this.computeFullTree = computeFullTree;
+            this.mergeInfoOutputTag = mergeInfoOutputTag;
+
+            distanceMeasure = DistanceMeasure.getInstance(distanceMeasureName);
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws 
Exception {
+            super.initializeState(context);
+            inputListState =
+                    context.getOperatorStateStore()
+                            .getListState(new 
ListStateDescriptor<>("inputListState", Row.class));
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws 
Exception {
+            super.snapshotState(context);
+        }
+
+        @Override
+        public void processElement(StreamRecord<Row> input) throws Exception {
+            inputListState.add(input.getValue());
+        }
+
+        @Override
+        @SuppressWarnings("unchecked")
+        public void endInput() throws Exception {
+            inputList = IteratorUtils.toList(inputListState.get().iterator());
+            int numDataPoints = inputList.size();
+
+            // Assigns initial cluster Ids.
+            clusterIds = new int[numDataPoints];
+            for (int i = 0; i < numDataPoints; i++) {
+                clusterIds[i] = getNextClusterId();
+            }
+
+            List<Cluster> activeClusters = new ArrayList<>();
+            for (int i = 0; i < numDataPoints; i++) {
+                List<Integer> dataPointIds = new ArrayList<>();
+                dataPointIds.add(i);
+                activeClusters.add(new Cluster(i, dataPointIds));
+            }
+
+            // Precomputes vector norms for faster computation.
+            vectorWithNorms = new VectorWithNorm[inputList.size()];
+            for (int i = 0; i < numDataPoints; i++) {
+                vectorWithNorms[i] =
+                        new VectorWithNorm((Vector) 
inputList.get(i).getField(featuresCol));
+            }
+
+            // Clustering process.
+            doClustering(activeClusters);
+
+            // Remaps the cluster Ids and output results.
+            HashMap<Integer, Integer> remappedClusterIds = new HashMap<>();
+            int cnt = 0;
+            for (int i = 0; i < clusterIds.length; i++) {
+                int clusterId = clusterIds[i];
+                if (remappedClusterIds.containsKey(clusterId)) {
+                    clusterIds[i] = remappedClusterIds.get(clusterId);
+                } else {
+                    clusterIds[i] = cnt;
+                    remappedClusterIds.put(clusterId, cnt++);
+                }
+            }
+
+            for (int i = 0; i < numDataPoints; i++) {
+                output.collect(
+                        new StreamRecord<>(Row.join(inputList.get(i), 
Row.of(clusterIds[i]))));
+            }
+        }
+
+        private int getNextClusterId() {
+            return nextClusterId++;
+        }
+
+        private void doClustering(List<Cluster> activeClusters) {
+            int clusterOffset1 = -1, clusterOffset2 = -1;
+            boolean clusteringRunning =
+                    (k != null && activeClusters.size() > k) || 
(distanceThreshold != null);
+
+            while (clusteringRunning || (computeFullTree && 
activeClusters.size() > 1)) {
+                // Computes the distance between two clusters.
+                double minDistance = Double.MAX_VALUE;
+                for (int i = 0; i < activeClusters.size(); i++) {
+                    for (int j = i + 1; j < activeClusters.size(); j++) {
+                        double distance =
+                                computeDistanceBetweenClusters(
+                                        activeClusters.get(i), 
activeClusters.get(j));
+                        if (distance < minDistance) {
+                            minDistance = distance;
+                            clusterOffset1 = i;
+                            clusterOffset2 = j;
+                        }
+                    }
+                }
+
+                // Outputs the merge info.
+                Cluster cluster1 = activeClusters.get(clusterOffset1);
+                Cluster cluster2 = activeClusters.get(clusterOffset2);
+                int clusterId1 = cluster1.clusterId;
+                int clusterId2 = cluster2.clusterId;
+                output.collect(
+                        mergeInfoOutputTag,
+                        new StreamRecord<>(
+                                Tuple4.of(
+                                        Math.min(clusterId1, clusterId2),
+                                        Math.max(clusterId1, clusterId2),
+                                        minDistance,
+                                        cluster1.dataPointIds.size()
+                                                + 
cluster2.dataPointIds.size())));
+
+                // Merges these two clusters.
+                Cluster mergedCluster =
+                        new Cluster(
+                                getNextClusterId(), cluster1.dataPointIds, 
cluster2.dataPointIds);
+                activeClusters.set(clusterOffset1, mergedCluster);
+                activeClusters.remove(clusterOffset2);
+
+                // Updates cluster Ids for each data point if clustering is 
still running.
+                if (clusteringRunning) {
+                    int mergedClusterId = mergedCluster.clusterId;
+                    for (int dataPointId : mergedCluster.dataPointIds) {
+                        clusterIds[dataPointId] = mergedClusterId;
+                    }
+                }
+
+                clusteringRunning =

Review Comment:
   No, we need to merge the clusters in this round first.
   
   Say the number of active clusters is `k+1` when entering the loop. If we
   change the order as above, the number of output clusters will be `k+1`
   instead of `k`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to