zhipeng93 commented on code in PR #171:
URL: https://github.com/apache/flink-ml/pull/171#discussion_r1036626947


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/agglomerativeclustering/AgglomerativeClustering.java:
##########
@@ -249,175 +264,223 @@ public void process(
             for (int i = 0; i < numDataPoints; i++) {
                 output.collect(Row.join(inputList.get(i), 
Row.of(clusterIds[i])));
             }
+
+            // Outputs the merge info.
+            if (computeFullTree) {
+                stoppedIdx = nnChain.size();
+            }
+            for (int i = 0; i < stoppedIdx; i++) {
+                Tuple4<Integer, Integer, Integer, Double> mergeItem = 
nnChain.get(i);
+                int cid1 = Math.min(mergeItem.f0, mergeItem.f1);
+                int cid2 = Math.max(mergeItem.f0, mergeItem.f1);
+                context.output(
+                        mergeInfoOutputTag,
+                        Tuple4.of(
+                                cid1,
+                                cid2,
+                                mergeItem.f3,
+                                nnChainAndSize.f1[cid1] + 
nnChainAndSize.f1[cid2]));
+            }
         }
 
-        private int getNextClusterId() {
-            return nextClusterId++;
+        /** Reorders the nearest-neighbor-chain. */
+        private void reOrderNnChain(List<Tuple4<Integer, Integer, Integer, 
Double>> nnChain) {
+            int nextClusterId = nnChain.size() + 1;
+            HashMap<Integer, Integer> nodeMapping = new HashMap<>();
+            for (Tuple4<Integer, Integer, Integer, Double> t : nnChain) {
+                if (nodeMapping.containsKey(t.f0)) {
+                    t.f0 = nodeMapping.get(t.f0);
+                }
+                if (nodeMapping.containsKey(t.f1)) {
+                    t.f1 = nodeMapping.get(t.f1);
+                }
+                nodeMapping.put(t.f2, nextClusterId);
+                nextClusterId++;
+            }
         }
 
-        private void doClustering(
-                List<Cluster> activeClusters,
-                ProcessAllWindowFunction<Row, Row, ?>.Context context) {
-            boolean clusteringRunning =
-                    (numCluster != null && activeClusters.size() > numCluster)
-                            || (distanceThreshold != null);
-
-            while (clusteringRunning || (computeFullTree && 
activeClusters.size() > 1)) {
-                int clusterOffset1 = -1, clusterOffset2 = -1;
-                // Computes the distance between two clusters.
-                double minDistance = Double.MAX_VALUE;
-                for (int i = 0; i < activeClusters.size(); i++) {
-                    for (int j = i + 1; j < activeClusters.size(); j++) {
-                        double distance =
-                                computeDistanceBetweenClusters(
-                                        activeClusters.get(i), 
activeClusters.get(j));
-                        if (distance < minDistance) {
-                            minDistance = distance;
-                            clusterOffset1 = i;
-                            clusterOffset2 = j;
+        /** Converts the cluster Ids for each input data point. */
+        private int[] label(
+                List<Tuple4<Integer, Integer, Integer, Double>> nnChains, int 
numDataPoints) {
+            UnionFind unionFind = new UnionFind(numDataPoints);
+            for (Tuple4<Integer, Integer, Integer, Double> t : nnChains) {
+                unionFind.union(unionFind.find(t.f0), unionFind.find(t.f1));
+            }
+            int[] clusterIds = new int[numDataPoints];
+            for (int i = 0; i < clusterIds.length; i++) {
+                clusterIds[i] = unionFind.find(i);
+            }
+            return clusterIds;
+        }
+
+        /** The main logic of nearest-neighbor-chain algorithm. */
+        private Tuple2<List<Tuple4<Integer, Integer, Integer, Double>>, int[]> 
nnChainCore(
+                HashSet<Integer> nodeLabels, DistanceMatrix distanceMatrix, 
String linkage) {
+            int numDataPoints = nodeLabels.size();
+            int nextClusterId = numDataPoints;
+            List<Tuple4<Integer, Integer, Integer, Double>> nnChain =
+                    new ArrayList<>(numDataPoints);
+            List<Integer> chain = new ArrayList<>();
+            int[] size = new int[numDataPoints * 2 - 1];
+            for (int i = 0; i < numDataPoints; i++) {
+                size[i] = 1;
+            }
+
+            int a, b;
+            while (nodeLabels.size() > 1) {
+                if (chain.size() <= 3) {
+                    Iterator<Integer> iterator = nodeLabels.iterator();
+                    a = iterator.next();
+                    chain.clear();
+                    chain.add(a);
+                    b = iterator.next();
+                } else {
+                    int chainSize = chain.size();
+                    a = chain.get(chainSize - 4);
+                    b = chain.get(chainSize - 3);
+                    chain.remove(chainSize - 1);
+                    chain.remove(chainSize - 2);
+                    chain.remove(chainSize - 3);
+                }
+
+                while (chain.size() < 3 || chain.get(chain.size() - 3) != a) {
+                    double minDistance = Double.MAX_VALUE;
+                    int c = -1;
+                    for (int x : nodeLabels) {
+                        if (x == a) {
+                            continue;
+                        }
+                        double dax = distanceMatrix.get(a, x);
+                        if (dax < minDistance) {
+                            c = x;
+                            minDistance = dax;
                         }
                     }
+                    if (minDistance == distanceMatrix.get(a, b) && 
nodeLabels.contains(b)) {
+                        c = b;
+                    }
+                    b = a;
+                    a = c;
+                    chain.add(a);
                 }
 
-                // Outputs the merge info.
-                Cluster cluster1 = activeClusters.get(clusterOffset1);
-                Cluster cluster2 = activeClusters.get(clusterOffset2);
-                int clusterId1 = cluster1.clusterId;
-                int clusterId2 = cluster2.clusterId;
-                context.output(
-                        mergeInfoOutputTag,
-                        Tuple4.of(
-                                Math.min(clusterId1, clusterId2),
-                                Math.max(clusterId1, clusterId2),
-                                minDistance,
-                                cluster1.dataPointIds.size() + 
cluster2.dataPointIds.size()));
-
-                // Merges these two clusters.
-                Cluster mergedCluster =
-                        new Cluster(
-                                getNextClusterId(), cluster1.dataPointIds, 
cluster2.dataPointIds);
-                activeClusters.set(clusterOffset1, mergedCluster);
-                activeClusters.remove(clusterOffset2);
-
-                // Updates cluster Ids for each data point if clustering is 
still running.
-                if (clusteringRunning) {
-                    int mergedClusterId = mergedCluster.clusterId;
-                    for (int dataPointId : mergedCluster.dataPointIds) {
-                        clusterIds[dataPointId] = mergedClusterId;
-                    }
+                int mergedNodeLabel = nextClusterId;
+                nnChain.add(Tuple4.of(a, b, mergedNodeLabel, 
distanceMatrix.get(a, b)));
+                nodeLabels.remove(a);
+                nodeLabels.remove(b);
+                nextClusterId++;
+                size[mergedNodeLabel] = size[a] + size[b];
+
+                for (int x : nodeLabels) {
+                    double d =
+                            computeClusterDistances(
+                                    distanceMatrix.get(a, x),
+                                    distanceMatrix.get(b, x),
+                                    distanceMatrix.get(a, b),
+                                    size[a],
+                                    size[b],
+                                    size[x],
+                                    linkage);
+                    distanceMatrix.set(x, mergedNodeLabel, d);
                 }
 
-                clusteringRunning =
-                        (numCluster != null && activeClusters.size() > 
numCluster)
-                                || (distanceThreshold != null
-                                        && distanceThreshold > minDistance
-                                        && activeClusters.size() > 1);
+                nodeLabels.add(mergedNodeLabel);
+            }
+
+            return Tuple2.of(nnChain, size);
+        }
+
+        /** Utility class for finding labels for input data points. */
+        private static class UnionFind {
+            private final int[] parent;
+            private int nextLabel;
+
+            public UnionFind(int numDataPoints) {
+                parent = new int[2 * numDataPoints - 1];
+                Arrays.fill(parent, -1);
+                nextLabel = numDataPoints;
+            }
+
+            public void union(int m, int n) {
+                parent[m] = nextLabel;
+                parent[n] = nextLabel;
+                nextLabel++;
+            }
+
+            public int find(int n) {
+                int p = n;
+                while (parent[n] != -1) {
+                    n = parent[n];
+                }
+                while (parent[p] != n && parent[p] != -1) {
+                    p = parent[p];
+                    parent[p] = n;
+                }
+                return n;
             }
         }
 
-        private double computeDistanceBetweenClusters(Cluster cluster1, 
Cluster cluster2) {
-            double distance;
-            int size1 = cluster1.dataPointIds.size();
-            int size2 = cluster2.dataPointIds.size();
+        /** Utility class for storing distances between every two clusters. */
+        private static class DistanceMatrix {
+            /** The storage of distances between each two clusters. */
+            private final double[] distances;
+            /** Number of clusters. */
+            private final int n;
 
+            public DistanceMatrix(int n) {
+                distances = new double[n * (n - 1) / 2];
+                this.n = n;
+            }
+
+            public void set(int i, int j, double value) {
+                int smallIdx = Math.min(i, j);
+                int bigIdx = Math.max(i, j);
+                int offset = (n * 2 - 1 - smallIdx) * smallIdx / 2 + (bigIdx - 
smallIdx - 1);
+                distances[offset] = value;
+            }
+
+            public double get(int i, int j) {
+                int smallIdx = Math.min(i, j);
+                int bigIdx = Math.max(i, j);
+                int offset = (n * 2 - 1 - smallIdx) * smallIdx / 2 + (bigIdx - 
smallIdx - 1);
+                return distances[offset];
+            }
+        }
+
+        /**
+         * Computes the distance between cluster k and the new cluster merged 
by cluster i and k.

Review Comment:
   Good catch. It is indeed a typo. I have updated the comment to `Computes the 
distance between cluster k and the new cluster merged by cluster i and j`.
   
   Suppose clusters `i` and `j` are about to be merged into a new cluster; this 
function computes the distance between that new cluster and any other 
cluster `k`.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/agglomerativeclustering/AgglomerativeClustering.java:
##########
@@ -249,175 +264,223 @@ public void process(
             for (int i = 0; i < numDataPoints; i++) {
                 output.collect(Row.join(inputList.get(i), 
Row.of(clusterIds[i])));
             }
+
+            // Outputs the merge info.
+            if (computeFullTree) {
+                stoppedIdx = nnChain.size();
+            }
+            for (int i = 0; i < stoppedIdx; i++) {
+                Tuple4<Integer, Integer, Integer, Double> mergeItem = 
nnChain.get(i);
+                int cid1 = Math.min(mergeItem.f0, mergeItem.f1);
+                int cid2 = Math.max(mergeItem.f0, mergeItem.f1);
+                context.output(
+                        mergeInfoOutputTag,
+                        Tuple4.of(
+                                cid1,
+                                cid2,
+                                mergeItem.f3,
+                                nnChainAndSize.f1[cid1] + 
nnChainAndSize.f1[cid2]));
+            }
         }
 
-        private int getNextClusterId() {
-            return nextClusterId++;
+        /** Reorders the nearest-neighbor-chain. */
+        private void reOrderNnChain(List<Tuple4<Integer, Integer, Integer, 
Double>> nnChain) {
+            int nextClusterId = nnChain.size() + 1;
+            HashMap<Integer, Integer> nodeMapping = new HashMap<>();
+            for (Tuple4<Integer, Integer, Integer, Double> t : nnChain) {
+                if (nodeMapping.containsKey(t.f0)) {
+                    t.f0 = nodeMapping.get(t.f0);
+                }
+                if (nodeMapping.containsKey(t.f1)) {
+                    t.f1 = nodeMapping.get(t.f1);
+                }
+                nodeMapping.put(t.f2, nextClusterId);
+                nextClusterId++;
+            }
         }
 
-        private void doClustering(
-                List<Cluster> activeClusters,
-                ProcessAllWindowFunction<Row, Row, ?>.Context context) {
-            boolean clusteringRunning =
-                    (numCluster != null && activeClusters.size() > numCluster)
-                            || (distanceThreshold != null);
-
-            while (clusteringRunning || (computeFullTree && 
activeClusters.size() > 1)) {
-                int clusterOffset1 = -1, clusterOffset2 = -1;
-                // Computes the distance between two clusters.
-                double minDistance = Double.MAX_VALUE;
-                for (int i = 0; i < activeClusters.size(); i++) {
-                    for (int j = i + 1; j < activeClusters.size(); j++) {
-                        double distance =
-                                computeDistanceBetweenClusters(
-                                        activeClusters.get(i), 
activeClusters.get(j));
-                        if (distance < minDistance) {
-                            minDistance = distance;
-                            clusterOffset1 = i;
-                            clusterOffset2 = j;
+        /** Converts the cluster Ids for each input data point. */
+        private int[] label(
+                List<Tuple4<Integer, Integer, Integer, Double>> nnChains, int 
numDataPoints) {
+            UnionFind unionFind = new UnionFind(numDataPoints);
+            for (Tuple4<Integer, Integer, Integer, Double> t : nnChains) {
+                unionFind.union(unionFind.find(t.f0), unionFind.find(t.f1));
+            }
+            int[] clusterIds = new int[numDataPoints];
+            for (int i = 0; i < clusterIds.length; i++) {
+                clusterIds[i] = unionFind.find(i);
+            }
+            return clusterIds;
+        }
+
+        /** The main logic of nearest-neighbor-chain algorithm. */
+        private Tuple2<List<Tuple4<Integer, Integer, Integer, Double>>, int[]> 
nnChainCore(
+                HashSet<Integer> nodeLabels, DistanceMatrix distanceMatrix, 
String linkage) {
+            int numDataPoints = nodeLabels.size();
+            int nextClusterId = numDataPoints;
+            List<Tuple4<Integer, Integer, Integer, Double>> nnChain =
+                    new ArrayList<>(numDataPoints);
+            List<Integer> chain = new ArrayList<>();
+            int[] size = new int[numDataPoints * 2 - 1];
+            for (int i = 0; i < numDataPoints; i++) {
+                size[i] = 1;
+            }
+
+            int a, b;
+            while (nodeLabels.size() > 1) {
+                if (chain.size() <= 3) {
+                    Iterator<Integer> iterator = nodeLabels.iterator();
+                    a = iterator.next();
+                    chain.clear();
+                    chain.add(a);
+                    b = iterator.next();
+                } else {
+                    int chainSize = chain.size();
+                    a = chain.get(chainSize - 4);
+                    b = chain.get(chainSize - 3);
+                    chain.remove(chainSize - 1);
+                    chain.remove(chainSize - 2);
+                    chain.remove(chainSize - 3);
+                }
+
+                while (chain.size() < 3 || chain.get(chain.size() - 3) != a) {
+                    double minDistance = Double.MAX_VALUE;
+                    int c = -1;
+                    for (int x : nodeLabels) {
+                        if (x == a) {
+                            continue;
+                        }
+                        double dax = distanceMatrix.get(a, x);
+                        if (dax < minDistance) {
+                            c = x;
+                            minDistance = dax;
                         }
                     }
+                    if (minDistance == distanceMatrix.get(a, b) && 
nodeLabels.contains(b)) {
+                        c = b;
+                    }
+                    b = a;
+                    a = c;
+                    chain.add(a);
                 }
 
-                // Outputs the merge info.
-                Cluster cluster1 = activeClusters.get(clusterOffset1);
-                Cluster cluster2 = activeClusters.get(clusterOffset2);
-                int clusterId1 = cluster1.clusterId;
-                int clusterId2 = cluster2.clusterId;
-                context.output(
-                        mergeInfoOutputTag,
-                        Tuple4.of(
-                                Math.min(clusterId1, clusterId2),
-                                Math.max(clusterId1, clusterId2),
-                                minDistance,
-                                cluster1.dataPointIds.size() + 
cluster2.dataPointIds.size()));
-
-                // Merges these two clusters.
-                Cluster mergedCluster =
-                        new Cluster(
-                                getNextClusterId(), cluster1.dataPointIds, 
cluster2.dataPointIds);
-                activeClusters.set(clusterOffset1, mergedCluster);
-                activeClusters.remove(clusterOffset2);
-
-                // Updates cluster Ids for each data point if clustering is 
still running.
-                if (clusteringRunning) {
-                    int mergedClusterId = mergedCluster.clusterId;
-                    for (int dataPointId : mergedCluster.dataPointIds) {
-                        clusterIds[dataPointId] = mergedClusterId;
-                    }
+                int mergedNodeLabel = nextClusterId;
+                nnChain.add(Tuple4.of(a, b, mergedNodeLabel, 
distanceMatrix.get(a, b)));
+                nodeLabels.remove(a);
+                nodeLabels.remove(b);
+                nextClusterId++;
+                size[mergedNodeLabel] = size[a] + size[b];
+
+                for (int x : nodeLabels) {
+                    double d =
+                            computeClusterDistances(
+                                    distanceMatrix.get(a, x),
+                                    distanceMatrix.get(b, x),
+                                    distanceMatrix.get(a, b),
+                                    size[a],
+                                    size[b],
+                                    size[x],
+                                    linkage);
+                    distanceMatrix.set(x, mergedNodeLabel, d);
                 }
 
-                clusteringRunning =
-                        (numCluster != null && activeClusters.size() > 
numCluster)
-                                || (distanceThreshold != null
-                                        && distanceThreshold > minDistance
-                                        && activeClusters.size() > 1);
+                nodeLabels.add(mergedNodeLabel);
+            }
+
+            return Tuple2.of(nnChain, size);
+        }
+
+        /** Utility class for finding labels for input data points. */
+        private static class UnionFind {
+            private final int[] parent;
+            private int nextLabel;
+
+            public UnionFind(int numDataPoints) {
+                parent = new int[2 * numDataPoints - 1];
+                Arrays.fill(parent, -1);
+                nextLabel = numDataPoints;
+            }
+
+            public void union(int m, int n) {
+                parent[m] = nextLabel;
+                parent[n] = nextLabel;
+                nextLabel++;
+            }
+
+            public int find(int n) {
+                int p = n;
+                while (parent[n] != -1) {
+                    n = parent[n];
+                }
+                while (parent[p] != n && parent[p] != -1) {
+                    p = parent[p];
+                    parent[p] = n;
+                }
+                return n;
             }
         }
 
-        private double computeDistanceBetweenClusters(Cluster cluster1, 
Cluster cluster2) {
-            double distance;
-            int size1 = cluster1.dataPointIds.size();
-            int size2 = cluster2.dataPointIds.size();
+        /** Utility class for storing distances between every two clusters. */
+        private static class DistanceMatrix {
+            /** The storage of distances between each two clusters. */
+            private final double[] distances;
+            /** Number of clusters. */
+            private final int n;
 
+            public DistanceMatrix(int n) {
+                distances = new double[n * (n - 1) / 2];
+                this.n = n;
+            }
+
+            public void set(int i, int j, double value) {
+                int smallIdx = Math.min(i, j);
+                int bigIdx = Math.max(i, j);
+                int offset = (n * 2 - 1 - smallIdx) * smallIdx / 2 + (bigIdx - 
smallIdx - 1);
+                distances[offset] = value;
+            }
+
+            public double get(int i, int j) {
+                int smallIdx = Math.min(i, j);
+                int bigIdx = Math.max(i, j);
+                int offset = (n * 2 - 1 - smallIdx) * smallIdx / 2 + (bigIdx - 
smallIdx - 1);
+                return distances[offset];
+            }
+        }
+
+        /**
+         * Computes the distance between cluster k and the new cluster merged 
by cluster i and k.

Review Comment:
   Good catch. It is indeed a typo. I have updated the comment to `Computes the 
distance between cluster k and the new cluster merged by cluster i and j`.
   
   Suppose clusters `i` and `j` are about to be merged into a new cluster; this 
function computes the distance between that new cluster and any other 
cluster `k`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to