yunfengzhou-hub commented on code in PR #148:
URL: https://github.com/apache/flink-ml/pull/148#discussion_r957971688
##########
flink-ml-python/pyflink/ml/lib/tests/test_ml_lib_completeness.py:
##########
@@ -70,7 +70,7 @@ def _load_stages_from_module(cls, module):
'JavaClusteringEstimator', 'JavaClusteringModel',
'JavaEvaluationAlgoOperator', 'JavaFeatureTransformer',
'JavaFeatureEstimator', 'JavaFeatureModel',
- 'JavaRegressionEstimator', 'JavaRegressionModel')]
+ 'JavaRegressionEstimator', 'JavaRegressionModel', 'JavaClusteringAlgoOperator')]
Review Comment:
nit: it might be better to place `JavaClusteringAlgoOperator` between
`JavaClassificationModel` and `JavaClusteringEstimator`.
##########
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/agglomerativeclustering/AgglomerativeClustering.java:
##########
@@ -0,0 +1,415 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.agglomerativeclustering;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.VectorWithNorm;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An AlgoOperator that performs a hierarchical clustering using a bottom up approach. Each
+ * observation starts in its own cluster and the clusters are merged together one by one. Users can
+ * choose different strategies to merge two clusters by setting {@link
+ * AgglomerativeClusteringParams#LINKAGE} and different distance measure by setting {@link
+ * AgglomerativeClusteringParams#DISTANCE_MEASURE}.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Hierarchical_clustering.
+ */
+public class AgglomerativeClustering
+ implements AlgoOperator<AgglomerativeClustering>,
+ AgglomerativeClusteringParams<AgglomerativeClustering> {
+ private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+ public AgglomerativeClustering() {
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+ }
+
+ @Override
+ public Table[] transform(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ Integer k = getK();
+ Double distanceThreshold = getDistanceThreshold();
+ Preconditions.checkArgument(
+ (k == null && distanceThreshold != null)
+ || (k != null && distanceThreshold == null),
+ "One of param k and distanceThreshold should be null.");
+
+ if (getLinkage().equals(LINKAGE_WARD)) {
+ String distanceMeasure = getDistanceMeasure();
+ Preconditions.checkArgument(
+ distanceMeasure.equals(EuclideanDistanceMeasure.NAME),
+ distanceMeasure
+ + " was provided as distance measure while linkage
was ward. Ward only works with euclidean.");
+ }
+
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+ DataStream<Row> dataStream = tEnv.toDataStream(inputs[0]);
+
+ RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+ RowTypeInfo outputTypeInfo =
+ new RowTypeInfo(
+ ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.INT),
+ ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol()));
+
+ OutputTag<Tuple4<Integer, Integer, Double, Integer>> mergeInfoOutputTag =
+ new OutputTag<Tuple4<Integer, Integer, Double, Integer>>("MERGE_INFO") {};
+
+ SingleOutputStreamOperator<Row> output =
+ dataStream.transform(
+ "doLocalAgglomerativeClustering",
+ outputTypeInfo,
+ new LocalAgglomerativeClusteringOperator(
+ getFeaturesCol(),
+ getLinkage(),
+ getDistanceMeasure(),
+ getK(),
+ getDistanceThreshold(),
+ getComputeFullTree(),
+ mergeInfoOutputTag));
+ output.getTransformation().setParallelism(1);
+
+ Table outputTable = tEnv.fromDataStream(output);
+
+ DataStream<Tuple4<Integer, Integer, Double, Integer>> mergeInfo =
+ output.getSideOutput(mergeInfoOutputTag);
+ Table mergeInfoTable =
+ tEnv.fromDataStream(mergeInfo)
+ .as("clusterId1", "clusterId2", "distance",
"sizeOfMergedCluster");
+
+ return new Table[] {outputTable, mergeInfoTable};
+ }
+
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ }
+
+ public static AgglomerativeClustering load(StreamTableEnvironment env, String path)
+ throws IOException {
+ return ReadWriteUtils.loadStageParam(path);
+ }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return paramMap;
+ }
+
+ private static class LocalAgglomerativeClusteringOperator extends AbstractStreamOperator<Row>
+ implements OneInputStreamOperator<Row, Row>, BoundedOneInput {
+ private final String featuresCol;
+ private final String linkage;
+ private final DistanceMeasure distanceMeasure;
+ private final Integer k;
+ private final Double distanceThreshold;
+ private final boolean computeFullTree;
+ private final OutputTag<Tuple4<Integer, Integer, Double, Integer>> mergeInfoOutputTag;
+
+ /** Cache for the input data. */
+ private List<Row> inputList;
Review Comment:
It seems that this variable can be converted to a local variable.
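For illustration, a minimal sketch of that change (assuming `inputList` is only read inside `endInput`, as in the current revision):
```java
@Override
@SuppressWarnings("unchecked")
public void endInput() throws Exception {
    // Declared locally since no other method reads it.
    List<Row> inputList = IteratorUtils.toList(inputListState.get().iterator());
    int numDataPoints = inputList.size();
    // ... the rest of endInput() stays unchanged ...
}
```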
##########
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/agglomerativeclustering/AgglomerativeClustering.java:
##########
@@ -0,0 +1,415 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.agglomerativeclustering;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.VectorWithNorm;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An AlgoOperator that performs a hierarchical clustering using a bottom up approach. Each
+ * observation starts in its own cluster and the clusters are merged together one by one. Users can
+ * choose different strategies to merge two clusters by setting {@link
+ * AgglomerativeClusteringParams#LINKAGE} and different distance measure by setting {@link
+ * AgglomerativeClusteringParams#DISTANCE_MEASURE}.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Hierarchical_clustering.
+ */
+public class AgglomerativeClustering
+ implements AlgoOperator<AgglomerativeClustering>,
+ AgglomerativeClusteringParams<AgglomerativeClustering> {
+ private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+ public AgglomerativeClustering() {
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+ }
+
+ @Override
+ public Table[] transform(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ Integer k = getK();
+ Double distanceThreshold = getDistanceThreshold();
+ Preconditions.checkArgument(
+ (k == null && distanceThreshold != null)
+ || (k != null && distanceThreshold == null),
+ "One of param k and distanceThreshold should be null.");
+
+ if (getLinkage().equals(LINKAGE_WARD)) {
+ String distanceMeasure = getDistanceMeasure();
+ Preconditions.checkArgument(
+ distanceMeasure.equals(EuclideanDistanceMeasure.NAME),
+ distanceMeasure
+ + " was provided as distance measure while linkage
was ward. Ward only works with euclidean.");
+ }
+
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+ DataStream<Row> dataStream = tEnv.toDataStream(inputs[0]);
+
+ RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+ RowTypeInfo outputTypeInfo =
+ new RowTypeInfo(
+ ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.INT),
+ ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol()));
+
+ OutputTag<Tuple4<Integer, Integer, Double, Integer>> mergeInfoOutputTag =
+ new OutputTag<Tuple4<Integer, Integer, Double, Integer>>("MERGE_INFO") {};
+
+ SingleOutputStreamOperator<Row> output =
+ dataStream.transform(
+ "doLocalAgglomerativeClustering",
+ outputTypeInfo,
+ new LocalAgglomerativeClusteringOperator(
+ getFeaturesCol(),
+ getLinkage(),
+ getDistanceMeasure(),
+ getK(),
+ getDistanceThreshold(),
+ getComputeFullTree(),
+ mergeInfoOutputTag));
+ output.getTransformation().setParallelism(1);
+
+ Table outputTable = tEnv.fromDataStream(output);
+
+ DataStream<Tuple4<Integer, Integer, Double, Integer>> mergeInfo =
+ output.getSideOutput(mergeInfoOutputTag);
+ Table mergeInfoTable =
+ tEnv.fromDataStream(mergeInfo)
+ .as("clusterId1", "clusterId2", "distance",
"sizeOfMergedCluster");
+
+ return new Table[] {outputTable, mergeInfoTable};
+ }
+
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ }
+
+ public static AgglomerativeClustering load(StreamTableEnvironment env, String path)
+ throws IOException {
+ return ReadWriteUtils.loadStageParam(path);
+ }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return paramMap;
+ }
+
+ private static class LocalAgglomerativeClusteringOperator extends AbstractStreamOperator<Row>
+ implements OneInputStreamOperator<Row, Row>, BoundedOneInput {
+ private final String featuresCol;
+ private final String linkage;
+ private final DistanceMeasure distanceMeasure;
+ private final Integer k;
+ private final Double distanceThreshold;
+ private final boolean computeFullTree;
+ private final OutputTag<Tuple4<Integer, Integer, Double, Integer>> mergeInfoOutputTag;
+
+ /** Cache for the input data. */
+ private List<Row> inputList;
+ /** State for the input data. */
+ private ListState<Row> inputListState;
+ /** Cluster id of each data point in inputList. */
+ private int[] clusterIds;
+ /** Precomputes the norm of each vector for performance. */
+ private VectorWithNorm[] vectorWithNorms;
+ /** Next cluster Id to be assigned. */
+ private int nextClusterId = 0;
+
+ public LocalAgglomerativeClusteringOperator(
+ String featuresCol,
+ String linkage,
+ String distanceMeasureName,
+ Integer k,
+ Double distanceThreshold,
+ boolean computeFullTree,
+ OutputTag<Tuple4<Integer, Integer, Double, Integer>> mergeInfoOutputTag) {
+ this.featuresCol = featuresCol;
+ this.linkage = linkage;
+ this.k = k;
+ this.distanceThreshold = distanceThreshold;
+ this.computeFullTree = computeFullTree;
+ this.mergeInfoOutputTag = mergeInfoOutputTag;
+
+ distanceMeasure = DistanceMeasure.getInstance(distanceMeasureName);
+ }
+
+ @Override
+ public void initializeState(StateInitializationContext context) throws Exception {
+ super.initializeState(context);
+ inputListState =
+ context.getOperatorStateStore()
+ .getListState(new ListStateDescriptor<>("inputListState", Row.class));
+ }
+
+ @Override
+ public void snapshotState(StateSnapshotContext context) throws Exception {
+ super.snapshotState(context);
+ }
+
+ @Override
+ public void processElement(StreamRecord<Row> input) throws Exception {
+ inputListState.add(input.getValue());
+ }
+
+ @Override
+ @SuppressWarnings("unchecked")
+ public void endInput() throws Exception {
+ inputList = IteratorUtils.toList(inputListState.get().iterator());
+ int numDataPoints = inputList.size();
+
+ // Assigns initial cluster Ids.
+ clusterIds = new int[numDataPoints];
+ for (int i = 0; i < numDataPoints; i++) {
+ clusterIds[i] = getNextClusterId();
+ }
+
+ List<Cluster> activeClusters = new ArrayList<>();
+ for (int i = 0; i < numDataPoints; i++) {
+ List<Integer> dataPointIds = new ArrayList<>();
+ dataPointIds.add(i);
+ activeClusters.add(new Cluster(i, dataPointIds));
+ }
+
+ // Precomputes vector norms for faster computation.
+ vectorWithNorms = new VectorWithNorm[inputList.size()];
+ for (int i = 0; i < numDataPoints; i++) {
+ vectorWithNorms[i] =
+ new VectorWithNorm((Vector) inputList.get(i).getField(featuresCol));
+ }
+
+ // Clustering process.
+ doClustering(activeClusters);
+
+ // Remaps the cluster Ids and output results.
+ HashMap<Integer, Integer> remappedClusterIds = new HashMap<>();
+ int cnt = 0;
+ for (int i = 0; i < clusterIds.length; i++) {
+ int clusterId = clusterIds[i];
+ if (remappedClusterIds.containsKey(clusterId)) {
+ clusterIds[i] = remappedClusterIds.get(clusterId);
+ } else {
+ clusterIds[i] = cnt;
+ remappedClusterIds.put(clusterId, cnt++);
+ }
+ }
+
+ for (int i = 0; i < numDataPoints; i++) {
+ output.collect(
+ new StreamRecord<>(Row.join(inputList.get(i), Row.of(clusterIds[i]))));
+ }
+ }
+
+ private int getNextClusterId() {
+ return nextClusterId++;
+ }
+
+ private void doClustering(List<Cluster> activeClusters) {
+ int clusterOffset1 = -1, clusterOffset2 = -1;
+ boolean clusteringRunning =
+ (k != null && activeClusters.size() > k) || (distanceThreshold != null);
+
+ while (clusteringRunning || (computeFullTree && activeClusters.size() > 1)) {
+ // Computes the distance between two clusters.
+ double minDistance = Double.MAX_VALUE;
+ for (int i = 0; i < activeClusters.size(); i++) {
+ for (int j = i + 1; j < activeClusters.size(); j++) {
+ double distance =
+ computeDistanceBetweenClusters(
+ activeClusters.get(i), activeClusters.get(j));
+ if (distance < minDistance) {
+ minDistance = distance;
+ clusterOffset1 = i;
+ clusterOffset2 = j;
+ }
+ }
+ }
+
+ // Outputs the merge info.
+ Cluster cluster1 = activeClusters.get(clusterOffset1);
+ Cluster cluster2 = activeClusters.get(clusterOffset2);
+ int clusterId1 = cluster1.clusterId;
+ int clusterId2 = cluster2.clusterId;
+ output.collect(
+ mergeInfoOutputTag,
+ new StreamRecord<>(
+ Tuple4.of(
+ Math.min(clusterId1, clusterId2),
+ Math.max(clusterId1, clusterId2),
+ minDistance,
+ cluster1.dataPointIds.size()
+ + cluster2.dataPointIds.size())));
+
+ // Merges these two clusters.
+ Cluster mergedCluster =
+ new Cluster(
+ getNextClusterId(), cluster1.dataPointIds, cluster2.dataPointIds);
+ activeClusters.set(clusterOffset1, mergedCluster);
+ activeClusters.remove(clusterOffset2);
+
+ // Updates cluster Ids for each data point if clustering is still running.
+ if (clusteringRunning) {
+ int mergedClusterId = mergedCluster.clusterId;
+ for (int dataPointId : mergedCluster.dataPointIds) {
+ clusterIds[dataPointId] = mergedClusterId;
+ }
+ }
+
+ clusteringRunning =
Review Comment:
How about changing the order of the logic here?
```java
clusteringRunning =
(k != null && activeClusters.size() > k)
|| (distanceThreshold != null && distanceThreshold > minDistance);
if (clusteringRunning) {
int mergedClusterId = mergedCluster.clusterId;
for (int dataPointId : mergedCluster.dataPointIds) {
clusterIds[dataPointId] = mergedClusterId;
}
}
```
##########
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/agglomerativeclustering/AgglomerativeClustering.java:
##########
@@ -0,0 +1,415 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.agglomerativeclustering;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.VectorWithNorm;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An AlgoOperator that performs a hierarchical clustering using a bottom up approach. Each
+ * observation starts in its own cluster and the clusters are merged together one by one. Users can
+ * choose different strategies to merge two clusters by setting {@link
+ * AgglomerativeClusteringParams#LINKAGE} and different distance measure by setting {@link
+ * AgglomerativeClusteringParams#DISTANCE_MEASURE}.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Hierarchical_clustering.
+ */
+public class AgglomerativeClustering
+ implements AlgoOperator<AgglomerativeClustering>,
+ AgglomerativeClusteringParams<AgglomerativeClustering> {
+ private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+ public AgglomerativeClustering() {
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+ }
+
+ @Override
+ public Table[] transform(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ Integer k = getK();
+ Double distanceThreshold = getDistanceThreshold();
+ Preconditions.checkArgument(
+ (k == null && distanceThreshold != null)
+ || (k != null && distanceThreshold == null),
+ "One of param k and distanceThreshold should be null.");
+
+ if (getLinkage().equals(LINKAGE_WARD)) {
+ String distanceMeasure = getDistanceMeasure();
+ Preconditions.checkArgument(
+ distanceMeasure.equals(EuclideanDistanceMeasure.NAME),
+ distanceMeasure
+ + " was provided as distance measure while linkage
was ward. Ward only works with euclidean.");
+ }
+
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+ DataStream<Row> dataStream = tEnv.toDataStream(inputs[0]);
+
+ RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+ RowTypeInfo outputTypeInfo =
+ new RowTypeInfo(
+ ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.INT),
+ ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol()));
+
+ OutputTag<Tuple4<Integer, Integer, Double, Integer>> mergeInfoOutputTag =
+ new OutputTag<Tuple4<Integer, Integer, Double, Integer>>("MERGE_INFO") {};
+
+ SingleOutputStreamOperator<Row> output =
+ dataStream.transform(
+ "doLocalAgglomerativeClustering",
+ outputTypeInfo,
+ new LocalAgglomerativeClusteringOperator(
+ getFeaturesCol(),
+ getLinkage(),
+ getDistanceMeasure(),
+ getK(),
+ getDistanceThreshold(),
+ getComputeFullTree(),
+ mergeInfoOutputTag));
+ output.getTransformation().setParallelism(1);
+
+ Table outputTable = tEnv.fromDataStream(output);
+
+ DataStream<Tuple4<Integer, Integer, Double, Integer>> mergeInfo =
+ output.getSideOutput(mergeInfoOutputTag);
+ Table mergeInfoTable =
+ tEnv.fromDataStream(mergeInfo)
+ .as("clusterId1", "clusterId2", "distance",
"sizeOfMergedCluster");
+
+ return new Table[] {outputTable, mergeInfoTable};
+ }
+
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ }
+
+ public static AgglomerativeClustering load(StreamTableEnvironment env, String path)
+ throws IOException {
+ return ReadWriteUtils.loadStageParam(path);
+ }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return paramMap;
+ }
+
+ private static class LocalAgglomerativeClusteringOperator extends AbstractStreamOperator<Row>
+ implements OneInputStreamOperator<Row, Row>, BoundedOneInput {
+ private final String featuresCol;
+ private final String linkage;
+ private final DistanceMeasure distanceMeasure;
+ private final Integer k;
+ private final Double distanceThreshold;
+ private final boolean computeFullTree;
+ private final OutputTag<Tuple4<Integer, Integer, Double, Integer>> mergeInfoOutputTag;
+
+ /** Cache for the input data. */
+ private List<Row> inputList;
+ /** State for the input data. */
+ private ListState<Row> inputListState;
+ /** Cluster id of each data point in inputList. */
+ private int[] clusterIds;
+ /** Precomputes the norm of each vector for performance. */
+ private VectorWithNorm[] vectorWithNorms;
+ /** Next cluster Id to be assigned. */
+ private int nextClusterId = 0;
+
+ public LocalAgglomerativeClusteringOperator(
+ String featuresCol,
+ String linkage,
+ String distanceMeasureName,
+ Integer k,
+ Double distanceThreshold,
+ boolean computeFullTree,
+ OutputTag<Tuple4<Integer, Integer, Double, Integer>> mergeInfoOutputTag) {
+ this.featuresCol = featuresCol;
+ this.linkage = linkage;
+ this.k = k;
+ this.distanceThreshold = distanceThreshold;
+ this.computeFullTree = computeFullTree;
+ this.mergeInfoOutputTag = mergeInfoOutputTag;
+
+ distanceMeasure = DistanceMeasure.getInstance(distanceMeasureName);
+ }
+
+ @Override
+ public void initializeState(StateInitializationContext context) throws Exception {
+ super.initializeState(context);
+ inputListState =
+ context.getOperatorStateStore()
+ .getListState(new ListStateDescriptor<>("inputListState", Row.class));
+ }
+
+ @Override
+ public void snapshotState(StateSnapshotContext context) throws Exception {
+ super.snapshotState(context);
+ }
+
+ @Override
+ public void processElement(StreamRecord<Row> input) throws Exception {
+ inputListState.add(input.getValue());
+ }
+
+ @Override
+ @SuppressWarnings("unchecked")
+ public void endInput() throws Exception {
+ inputList = IteratorUtils.toList(inputListState.get().iterator());
+ int numDataPoints = inputList.size();
+
+ // Assigns initial cluster Ids.
+ clusterIds = new int[numDataPoints];
+ for (int i = 0; i < numDataPoints; i++) {
+ clusterIds[i] = getNextClusterId();
+ }
+
+ List<Cluster> activeClusters = new ArrayList<>();
+ for (int i = 0; i < numDataPoints; i++) {
+ List<Integer> dataPointIds = new ArrayList<>();
+ dataPointIds.add(i);
+ activeClusters.add(new Cluster(i, dataPointIds));
+ }
+
+ // Precomputes vector norms for faster computation.
+ vectorWithNorms = new VectorWithNorm[inputList.size()];
+ for (int i = 0; i < numDataPoints; i++) {
+ vectorWithNorms[i] =
+ new VectorWithNorm((Vector) inputList.get(i).getField(featuresCol));
+ }
+
+ // Clustering process.
+ doClustering(activeClusters);
+
+ // Remaps the cluster Ids and output results.
+ HashMap<Integer, Integer> remappedClusterIds = new HashMap<>();
+ int cnt = 0;
+ for (int i = 0; i < clusterIds.length; i++) {
+ int clusterId = clusterIds[i];
+ if (remappedClusterIds.containsKey(clusterId)) {
+ clusterIds[i] = remappedClusterIds.get(clusterId);
+ } else {
+ clusterIds[i] = cnt;
+ remappedClusterIds.put(clusterId, cnt++);
+ }
+ }
+
+ for (int i = 0; i < numDataPoints; i++) {
+ output.collect(
+ new StreamRecord<>(Row.join(inputList.get(i), Row.of(clusterIds[i]))));
+ }
+ }
+
+ private int getNextClusterId() {
+ return nextClusterId++;
+ }
+
+ private void doClustering(List<Cluster> activeClusters) {
+ int clusterOffset1 = -1, clusterOffset2 = -1;
+ boolean clusteringRunning =
+ (k != null && activeClusters.size() > k) || (distanceThreshold != null);
+
+ while (clusteringRunning || (computeFullTree && activeClusters.size() > 1)) {
+ // Computes the distance between two clusters.
+ double minDistance = Double.MAX_VALUE;
+ for (int i = 0; i < activeClusters.size(); i++) {
+ for (int j = i + 1; j < activeClusters.size(); j++) {
+ double distance =
+ computeDistanceBetweenClusters(
+ activeClusters.get(i), activeClusters.get(j));
+ if (distance < minDistance) {
+ minDistance = distance;
+ clusterOffset1 = i;
+ clusterOffset2 = j;
+ }
+ }
+ }
+
+ // Outputs the merge info.
+ Cluster cluster1 = activeClusters.get(clusterOffset1);
+ Cluster cluster2 = activeClusters.get(clusterOffset2);
+ int clusterId1 = cluster1.clusterId;
+ int clusterId2 = cluster2.clusterId;
+ output.collect(
+ mergeInfoOutputTag,
+ new StreamRecord<>(
+ Tuple4.of(
+ Math.min(clusterId1, clusterId2),
+ Math.max(clusterId1, clusterId2),
+ minDistance,
+ cluster1.dataPointIds.size()
+ + cluster2.dataPointIds.size())));
+
+ // Merges these two clusters.
+ Cluster mergedCluster =
+ new Cluster(
+ getNextClusterId(), cluster1.dataPointIds, cluster2.dataPointIds);
+ activeClusters.set(clusterOffset1, mergedCluster);
+ activeClusters.remove(clusterOffset2);
+
+ // Updates cluster Ids for each data point if clustering is still running.
+ if (clusteringRunning) {
+ int mergedClusterId = mergedCluster.clusterId;
+ for (int dataPointId : mergedCluster.dataPointIds) {
+ clusterIds[dataPointId] = mergedClusterId;
+ }
+ }
+
+ clusteringRunning =
+ (k != null && activeClusters.size() > k)
+ || (distanceThreshold != null && distanceThreshold > minDistance);
+ }
+ }
+
+ private double computeDistanceBetweenClusters(Cluster cluster1, Cluster cluster2) {
+ double distance;
+ int size1 = cluster1.dataPointIds.size();
+ int size2 = cluster2.dataPointIds.size();
+
+ switch (linkage) {
+ case LINKAGE_AVERAGE:
+ distance = 0;
+ for (int i = 0; i < size1; i++) {
+ for (int j = 0; j < size2; j++) {
+ VectorWithNorm vectorWithNorm1 =
+ vectorWithNorms[cluster1.dataPointIds.get(i)];
+ VectorWithNorm vectorWithNorm2 =
+ vectorWithNorms[cluster2.dataPointIds.get(j)];
+ distance += distanceMeasure.distance(vectorWithNorm1, vectorWithNorm2);
+ }
+ }
+ distance /= size1 * size2;
+ break;
+ case LINKAGE_COMPLETE:
+ distance = Double.MIN_VALUE;
+ for (int i = 0; i < size1; i++) {
+ for (int j = 0; j < size2; j++) {
+ VectorWithNorm vectorWithNorm1 =
+ vectorWithNorms[cluster1.dataPointIds.get(i)];
+ VectorWithNorm vectorWithNorm2 =
+ vectorWithNorms[cluster2.dataPointIds.get(j)];
+ distance =
+ Math.max(
+ distance,
+ distanceMeasure.distance(
+ vectorWithNorm1, vectorWithNorm2));
+ }
+ }
+ break;
+ case LINKAGE_SINGLE:
+ distance = Double.MAX_VALUE;
+ for (int i = 0; i < size1; i++) {
+ for (int j = 0; j < size2; j++) {
+ VectorWithNorm vectorWithNorm1 =
+ vectorWithNorms[cluster1.dataPointIds.get(i)];
+ VectorWithNorm vectorWithNorm2 =
+ vectorWithNorms[cluster2.dataPointIds.get(j)];
+ distance =
+ Math.min(
+ distance,
+ distanceMeasure.distance(
+ vectorWithNorm1, vectorWithNorm2));
+ }
+ }
+ break;
+ case LINKAGE_WARD:
+ int vecSize = vectorWithNorms[0].vector.size();
+ DenseVector mean1 = Vectors.dense(new double[vecSize]);
+ DenseVector mean2 = Vectors.dense(new double[vecSize]);
+
+ for (int i = 0; i < size1; i++) {
+ BLAS.axpy(1.0, vectorWithNorms[cluster1.dataPointIds.get(i)].vector, mean1);
+ }
+ for (int i = 0; i < size2; i++) {
+ BLAS.axpy(1.0, vectorWithNorms[cluster2.dataPointIds.get(i)].vector, mean2);
+ }
+
+ DenseVector meanMerged = mean1.clone();
+ BLAS.axpy(1.0, mean2, meanMerged);
+ BLAS.scal(1.0 / size1, mean1);
+ BLAS.scal(1.0 / size2, mean2);
+ BLAS.scal(1.0 / (size1 + size2), meanMerged);
+ double essInc =
+ size1 * BLAS.dot(mean1, mean1)
+ + size2 * BLAS.dot(mean2, mean2)
+ - (size1 + size2) * BLAS.dot(meanMerged, meanMerged);
+
+ distance = Math.sqrt(2 * essInc);
+ break;
+ default:
+ throw new UnsupportedOperationException(
+ "Unsupported " + LINKAGE + " type: " + linkage +
".");
+ }
+ return distance;
+ }
+
+ /** A cluster with cluster Id specified and data points that belong to this cluster. */
+ private static class Cluster {
+ private final int clusterId;
+ private final List<Integer> dataPointIds;
+
+ @SafeVarargs
+ public Cluster(
+ int clusterId, List<Integer> dataPointIds, List<Integer>... otherDataPointIds) {
Review Comment:
It might be better to write two constructors, `Cluster(int, List<Integer>)`
and `Cluster(int, List<Integer>, List<Integer>)`.
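For example, a minimal sketch of the two overloads (field assignments assumed from the current class):
```java
public Cluster(int clusterId, List<Integer> dataPointIds) {
    this.clusterId = clusterId;
    this.dataPointIds = dataPointIds;
}

public Cluster(int clusterId, List<Integer> dataPointIds1, List<Integer> dataPointIds2) {
    this.clusterId = clusterId;
    // Merge the two id lists into a fresh list so the inputs are not mutated.
    this.dataPointIds = new ArrayList<>(dataPointIds1);
    this.dataPointIds.addAll(dataPointIds2);
}
```
This would also let us drop the `@SafeVarargs` annotation.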
##########
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/agglomerativeclustering/AgglomerativeClusteringParams.java:
##########
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.agglomerativeclustering;
+
+import org.apache.flink.ml.common.param.HasDistanceMeasure;
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.param.BooleanParam;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+
+/**
+ * Params of {@link AgglomerativeClustering}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface AgglomerativeClusteringParams<T>
+ extends HasDistanceMeasure<T>, HasFeaturesCol<T>, HasPredictionCol<T> {
+ Param<Integer> K = new IntParam("k", "The max number of clusters to
create.", 2);
Review Comment:
`k` has a clear meaning in KMeans and Knn, but I'm not sure it is as clear to users of the
agglomerative clustering algorithm. How about changing this parameter to `numClusters` or
`maxClusters`?
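For illustration, the renamed declaration could look like the following (a sketch; `numClusters` is only the suggested name, not an existing parameter):
```java
Param<Integer> NUM_CLUSTERS =
        new IntParam("numClusters", "The max number of clusters to create.", 2);
```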
##########
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/agglomerativeclustering/AgglomerativeClustering.java:
##########
@@ -0,0 +1,415 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.agglomerativeclustering;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.VectorWithNorm;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An AlgoOperator that performs a hierarchical clustering using a bottom up approach. Each
+ * observation starts in its own cluster and the clusters are merged together one by one. Users can
+ * choose different strategies to merge two clusters by setting {@link
+ * AgglomerativeClusteringParams#LINKAGE} and different distance measure by setting {@link
+ * AgglomerativeClusteringParams#DISTANCE_MEASURE}.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Hierarchical_clustering.
+ */
+public class AgglomerativeClustering
+ implements AlgoOperator<AgglomerativeClustering>,
+ AgglomerativeClusteringParams<AgglomerativeClustering> {
+ private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+ public AgglomerativeClustering() {
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+ }
+
+ @Override
+ public Table[] transform(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ Integer k = getK();
+ Double distanceThreshold = getDistanceThreshold();
+ Preconditions.checkArgument(
+ (k == null && distanceThreshold != null)
+ || (k != null && distanceThreshold == null),
+ "One of param k and distanceThreshold should be null.");
+
+ if (getLinkage().equals(LINKAGE_WARD)) {
+ String distanceMeasure = getDistanceMeasure();
+ Preconditions.checkArgument(
+ distanceMeasure.equals(EuclideanDistanceMeasure.NAME),
+ distanceMeasure
+ + " was provided as distance measure while linkage
was ward. Ward only works with euclidean.");
+ }
+
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+ DataStream<Row> dataStream = tEnv.toDataStream(inputs[0]);
+
+ RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+ RowTypeInfo outputTypeInfo =
+ new RowTypeInfo(
+ ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.INT),
+ ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol()));
+
+ OutputTag<Tuple4<Integer, Integer, Double, Integer>> mergeInfoOutputTag =
+ new OutputTag<Tuple4<Integer, Integer, Double, Integer>>("MERGE_INFO") {};
+
+ SingleOutputStreamOperator<Row> output =
+ dataStream.transform(
+ "doLocalAgglomerativeClustering",
+ outputTypeInfo,
+ new LocalAgglomerativeClusteringOperator(
+ getFeaturesCol(),
+ getLinkage(),
+ getDistanceMeasure(),
+ getK(),
+ getDistanceThreshold(),
+ getComputeFullTree(),
+ mergeInfoOutputTag));
+ output.getTransformation().setParallelism(1);
+
+ Table outputTable = tEnv.fromDataStream(output);
+
+ DataStream<Tuple4<Integer, Integer, Double, Integer>> mergeInfo =
+ output.getSideOutput(mergeInfoOutputTag);
+ Table mergeInfoTable =
+ tEnv.fromDataStream(mergeInfo)
+ .as("clusterId1", "clusterId2", "distance",
"sizeOfMergedCluster");
+
+ return new Table[] {outputTable, mergeInfoTable};
+ }
+
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ }
+
+ public static AgglomerativeClustering load(StreamTableEnvironment env, String path)
Review Comment:
nit: `StreamTableEnvironment tEnv`
##########
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/agglomerativeclustering/AgglomerativeClustering.java:
##########
@@ -0,0 +1,415 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.agglomerativeclustering;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.VectorWithNorm;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An AlgoOperator that performs a hierarchical clustering using a bottom up approach. Each
+ * observation starts in its own cluster and the clusters are merged together one by one. Users can
+ * choose different strategies to merge two clusters by setting {@link
+ * AgglomerativeClusteringParams#LINKAGE} and different distance measure by setting {@link
+ * AgglomerativeClusteringParams#DISTANCE_MEASURE}.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Hierarchical_clustering.
+ */
+public class AgglomerativeClustering
+ implements AlgoOperator<AgglomerativeClustering>,
+ AgglomerativeClusteringParams<AgglomerativeClustering> {
+ private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+ public AgglomerativeClustering() {
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+ }
+
+ @Override
+ public Table[] transform(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ Integer k = getK();
+ Double distanceThreshold = getDistanceThreshold();
+ Preconditions.checkArgument(
+ (k == null && distanceThreshold != null)
+ || (k != null && distanceThreshold == null),
+ "One of param k and distanceThreshold should be null.");
+
+ if (getLinkage().equals(LINKAGE_WARD)) {
+ String distanceMeasure = getDistanceMeasure();
+ Preconditions.checkArgument(
+ distanceMeasure.equals(EuclideanDistanceMeasure.NAME),
+ distanceMeasure
+ + " was provided as distance measure while linkage
was ward. Ward only works with euclidean.");
+ }
+
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+ DataStream<Row> dataStream = tEnv.toDataStream(inputs[0]);
+
+ RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+ RowTypeInfo outputTypeInfo =
+ new RowTypeInfo(
+ ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.INT),
+ ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol()));
+
+ OutputTag<Tuple4<Integer, Integer, Double, Integer>> mergeInfoOutputTag =
+ new OutputTag<Tuple4<Integer, Integer, Double, Integer>>("MERGE_INFO") {};
+
+ SingleOutputStreamOperator<Row> output =
+ dataStream.transform(
+ "doLocalAgglomerativeClustering",
+ outputTypeInfo,
+ new LocalAgglomerativeClusteringOperator(
+ getFeaturesCol(),
+ getLinkage(),
+ getDistanceMeasure(),
+ getK(),
+ getDistanceThreshold(),
+ getComputeFullTree(),
+ mergeInfoOutputTag));
+ output.getTransformation().setParallelism(1);
+
+ Table outputTable = tEnv.fromDataStream(output);
+
+ DataStream<Tuple4<Integer, Integer, Double, Integer>> mergeInfo =
+ output.getSideOutput(mergeInfoOutputTag);
+ Table mergeInfoTable =
+ tEnv.fromDataStream(mergeInfo)
+ .as("clusterId1", "clusterId2", "distance",
"sizeOfMergedCluster");
+
+ return new Table[] {outputTable, mergeInfoTable};
Review Comment:
Let's describe in the JavaDoc the column names of the `mergeInfoTable` and their meanings.
Besides, do you think it would be better to add a method like `Table getMergeInfo()`, instead of
making the merge info table the second output table? I'm not sure the current API matches the
conventions of Flink ML AlgoOperators. It also makes it difficult to connect the
`AgglomerativeClustering` stage with other stages in a `Pipeline`.
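For illustration, a rough sketch of such an accessor (hypothetical; `transform` would then return only the prediction table and store the merge info in a field):
```java
private Table mergeInfoTable;

/** Returns the merge info table produced by the last call to transform(...). */
public Table getMergeInfo() {
    Preconditions.checkState(
            mergeInfoTable != null, "getMergeInfo() must be called after transform(...).");
    return mergeInfoTable;
}
```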
##########
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/agglomerativeclustering/AgglomerativeClusteringParams.java:
##########
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.agglomerativeclustering;
+
+import org.apache.flink.ml.common.param.HasDistanceMeasure;
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.param.BooleanParam;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+
+/**
+ * Params of {@link AgglomerativeClustering}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface AgglomerativeClusteringParams<T>
+ extends HasDistanceMeasure<T>, HasFeaturesCol<T>, HasPredictionCol<T> {
+ Param<Integer> K = new IntParam("k", "The max number of clusters to
create.", 2);
+
+ Param<Double> DISTANCE_THRESHOLD =
+ new DoubleParam(
+ "distanceThreshold",
+ "Threshold to decide whether two clusters should be
merged.",
+ null);
+
+ String LINKAGE_WARD = "ward";
+ String LINKAGE_COMPLETE = "complete";
+ String LINKAGE_SINGLE = "single";
+ String LINKAGE_AVERAGE = "average";
+ /**
+ * Supported options to compute the distance between two clusters.
+ *
+ * <ul>
+ * <li>ward: difference between the sum of the variance of the two clusters and the merged
+ * one.
+ * <li>complete: the maximum distance between all observations of the two clusters.
+ * <li>single: the minimum distance between all observations of the two clusters.
+ * <li>average: the average of the distance of all observations of the two clusters.
+ * </ul>
+ */
+ Param<String> LINKAGE =
+ new StringParam(
+ "linkage",
+ "Criterion for computing distance between two clusters.",
+ LINKAGE_WARD,
+ ParamValidators.inArray(
+ LINKAGE_WARD, LINKAGE_COMPLETE, LINKAGE_AVERAGE, LINKAGE_SINGLE));
+
+ Param<Boolean> COMPUTE_FULL_TREE =
+ new BooleanParam(
+ "computeFullTree",
+ "Whether computes the full tree after convergence.",
+ false,
Review Comment:
The default behavior here is different from that in sklearn. Do you think we
should also support the `auto` option here?
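For reference, a sketch of what an `auto` option could look like if the parameter were modeled as a string instead of a boolean (hypothetical; assumes `auto` would resolve to `true` when `distanceThreshold` is set and `false` otherwise, mirroring sklearn's `compute_full_tree`):
```java
String COMPUTE_FULL_TREE_AUTO = "auto";

Param<String> COMPUTE_FULL_TREE =
        new StringParam(
                "computeFullTree",
                "Whether to compute the full tree after convergence.",
                COMPUTE_FULL_TREE_AUTO,
                ParamValidators.inArray("true", "false", COMPUTE_FULL_TREE_AUTO));
```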
##########
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/agglomerativeclustering/AgglomerativeClustering.java:
##########
@@ -0,0 +1,415 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.clustering.agglomerativeclustering;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.distance.DistanceMeasure;
+import org.apache.flink.ml.common.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.VectorWithNorm;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An AlgoOperator that performs a hierarchical clustering using a bottom up approach. Each
+ * observation starts in its own cluster and the clusters are merged together one by one. Users can
+ * choose different strategies to merge two clusters by setting {@link
+ * AgglomerativeClusteringParams#LINKAGE} and different distance measure by setting {@link
+ * AgglomerativeClusteringParams#DISTANCE_MEASURE}.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Hierarchical_clustering.
+ */
+public class AgglomerativeClustering
+ implements AlgoOperator<AgglomerativeClustering>,
+ AgglomerativeClusteringParams<AgglomerativeClustering> {
+ private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+ public AgglomerativeClustering() {
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+ }
+
+ @Override
+ public Table[] transform(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ Integer k = getK();
+ Double distanceThreshold = getDistanceThreshold();
+ Preconditions.checkArgument(
+ (k == null && distanceThreshold != null)
+ || (k != null && distanceThreshold == null),
+ "One of param k and distanceThreshold should be null.");
+
+ if (getLinkage().equals(LINKAGE_WARD)) {
+ String distanceMeasure = getDistanceMeasure();
+ Preconditions.checkArgument(
+ distanceMeasure.equals(EuclideanDistanceMeasure.NAME),
+ distanceMeasure
+ + " was provided as distance measure while linkage
was ward. Ward only works with euclidean.");
+ }
+
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+ DataStream<Row> dataStream = tEnv.toDataStream(inputs[0]);
+
+ RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+ RowTypeInfo outputTypeInfo =
+ new RowTypeInfo(
+ ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.INT),
+ ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol()));
+
+ OutputTag<Tuple4<Integer, Integer, Double, Integer>> mergeInfoOutputTag =
+ new OutputTag<Tuple4<Integer, Integer, Double, Integer>>("MERGE_INFO") {};
+
+ SingleOutputStreamOperator<Row> output =
+ dataStream.transform(
+ "doLocalAgglomerativeClustering",
+ outputTypeInfo,
+ new LocalAgglomerativeClusteringOperator(
+ getFeaturesCol(),
+ getLinkage(),
+ getDistanceMeasure(),
+ getK(),
+ getDistanceThreshold(),
+ getComputeFullTree(),
+ mergeInfoOutputTag));
+ output.getTransformation().setParallelism(1);
+
+ Table outputTable = tEnv.fromDataStream(output);
+
+ DataStream<Tuple4<Integer, Integer, Double, Integer>> mergeInfo =
+ output.getSideOutput(mergeInfoOutputTag);
+ Table mergeInfoTable =
+ tEnv.fromDataStream(mergeInfo)
+ .as("clusterId1", "clusterId2", "distance",
"sizeOfMergedCluster");
+
+ return new Table[] {outputTable, mergeInfoTable};
+ }
+
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ }
+
+ public static AgglomerativeClustering load(StreamTableEnvironment env, String path)
+ throws IOException {
+ return ReadWriteUtils.loadStageParam(path);
+ }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return paramMap;
+ }
+
+ private static class LocalAgglomerativeClusteringOperator extends AbstractStreamOperator<Row>
+ implements OneInputStreamOperator<Row, Row>, BoundedOneInput {
+ private final String featuresCol;
+ private final String linkage;
+ private final DistanceMeasure distanceMeasure;
+ private final Integer k;
+ private final Double distanceThreshold;
+ private final boolean computeFullTree;
+ private final OutputTag<Tuple4<Integer, Integer, Double, Integer>> mergeInfoOutputTag;
+
+ /** Cache for the input data. */
+ private List<Row> inputList;
+ /** State for the input data. */
+ private ListState<Row> inputListState;
+ /** Cluster id of each data point in inputList. */
+ private int[] clusterIds;
+ /** Precomputes the norm of each vector for performance. */
+ private VectorWithNorm[] vectorWithNorms;
+ /** Next cluster Id to be assigned. */
+ private int nextClusterId = 0;
+
+ public LocalAgglomerativeClusteringOperator(
+ String featuresCol,
+ String linkage,
+ String distanceMeasureName,
+ Integer k,
+ Double distanceThreshold,
+ boolean computeFullTree,
+ OutputTag<Tuple4<Integer, Integer, Double, Integer>> mergeInfoOutputTag) {
+ this.featuresCol = featuresCol;
+ this.linkage = linkage;
+ this.k = k;
+ this.distanceThreshold = distanceThreshold;
+ this.computeFullTree = computeFullTree;
+ this.mergeInfoOutputTag = mergeInfoOutputTag;
+
+ distanceMeasure = DistanceMeasure.getInstance(distanceMeasureName);
+ }
+
+ @Override
+ public void initializeState(StateInitializationContext context) throws Exception {
+ super.initializeState(context);
+ inputListState =
+ context.getOperatorStateStore()
+ .getListState(new ListStateDescriptor<>("inputListState", Row.class));
+ }
+
+ @Override
+ public void snapshotState(StateSnapshotContext context) throws Exception {
+ super.snapshotState(context);
+ }
+
+ @Override
+ public void processElement(StreamRecord<Row> input) throws Exception {
+ inputListState.add(input.getValue());
+ }
+
+ @Override
+ @SuppressWarnings("unchecked")
+ public void endInput() throws Exception {
+ inputList = IteratorUtils.toList(inputListState.get().iterator());
+ int numDataPoints = inputList.size();
+
+ // Assigns initial cluster Ids.
+ clusterIds = new int[numDataPoints];
+ for (int i = 0; i < numDataPoints; i++) {
+ clusterIds[i] = getNextClusterId();
+ }
+
+ List<Cluster> activeClusters = new ArrayList<>();
+ for (int i = 0; i < numDataPoints; i++) {
+ List<Integer> dataPointIds = new ArrayList<>();
+ dataPointIds.add(i);
+ activeClusters.add(new Cluster(i, dataPointIds));
+ }
+
+ // Precomputes vector norms for faster computation.
+ vectorWithNorms = new VectorWithNorm[inputList.size()];
+ for (int i = 0; i < numDataPoints; i++) {
+ vectorWithNorms[i] =
+ new VectorWithNorm((Vector) inputList.get(i).getField(featuresCol));
+ }
+
+ // Clustering process.
+ doClustering(activeClusters);
+
+ // Remaps the cluster Ids and output results.
+ HashMap<Integer, Integer> remappedClusterIds = new HashMap<>();
+ int cnt = 0;
+ for (int i = 0; i < clusterIds.length; i++) {
+ int clusterId = clusterIds[i];
+ if (remappedClusterIds.containsKey(clusterId)) {
+ clusterIds[i] = remappedClusterIds.get(clusterId);
+ } else {
+ clusterIds[i] = cnt;
+ remappedClusterIds.put(clusterId, cnt++);
+ }
+ }
+
+ for (int i = 0; i < numDataPoints; i++) {
+ output.collect(
+ new StreamRecord<>(Row.join(inputList.get(i), Row.of(clusterIds[i]))));
+ }
+ }
+
+ private int getNextClusterId() {
+ return nextClusterId++;
+ }
+
+ private void doClustering(List<Cluster> activeClusters) {
+ int clusterOffset1 = -1, clusterOffset2 = -1;
+ boolean clusteringRunning =
+ (k != null && activeClusters.size() > k) || (distanceThreshold != null);
+
+ while (clusteringRunning || (computeFullTree && activeClusters.size() > 1)) {
+ // Computes the distance between two clusters.
+ double minDistance = Double.MAX_VALUE;
+ for (int i = 0; i < activeClusters.size(); i++) {
+ for (int j = i + 1; j < activeClusters.size(); j++) {
+ double distance =
+ computeDistanceBetweenClusters(
+ activeClusters.get(i), activeClusters.get(j));
+ if (distance < minDistance) {
+ minDistance = distance;
+ clusterOffset1 = i;
+ clusterOffset2 = j;
+ }
+ }
+ }
+
+ // Outputs the merge info.
+ Cluster cluster1 = activeClusters.get(clusterOffset1);
+ Cluster cluster2 = activeClusters.get(clusterOffset2);
+ int clusterId1 = cluster1.clusterId;
+ int clusterId2 = cluster2.clusterId;
+ output.collect(
+ mergeInfoOutputTag,
+ new StreamRecord<>(
+ Tuple4.of(
+ Math.min(clusterId1, clusterId2),
+ Math.max(clusterId1, clusterId2),
+ minDistance,
+ cluster1.dataPointIds.size()
+ + cluster2.dataPointIds.size())));
+
+ // Merges these two clusters.
+ Cluster mergedCluster =
+ new Cluster(
+ getNextClusterId(), cluster1.dataPointIds, cluster2.dataPointIds);
+ activeClusters.set(clusterOffset1, mergedCluster);
+ activeClusters.remove(clusterOffset2);
+
+ // Updates cluster Ids for each data point if clustering is still running.
+ if (clusteringRunning) {
+ int mergedClusterId = mergedCluster.clusterId;
+ for (int dataPointId : mergedCluster.dataPointIds) {
+ clusterIds[dataPointId] = mergedClusterId;
+ }
+ }
+
+ clusteringRunning =
+ (k != null && activeClusters.size() > k)
+ || (distanceThreshold != null && distanceThreshold > minDistance);
+ }
+ }
+
+ private double computeDistanceBetweenClusters(Cluster cluster1, Cluster cluster2) {
Review Comment:
How about adding a cache here for the distance between any two data points?
We can avoid re-computing the distance if there is a cache hit.
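For illustration, a minimal sketch of such a cache keyed by the pair of data point ids (hypothetical names; assumes the ids are non-negative ints, as in the current code):
```java
/** Caches the distance between any two data points to avoid recomputation. */
private final Map<Long, Double> pairwiseDistanceCache = new HashMap<>();

private double cachedDistance(int dataPointId1, int dataPointId2) {
    // Pack the ordered pair (min, max) into a single long key.
    long key =
            ((long) Math.min(dataPointId1, dataPointId2) << 32)
                    | (Math.max(dataPointId1, dataPointId2) & 0xffffffffL);
    return pairwiseDistanceCache.computeIfAbsent(
            key,
            unused ->
                    distanceMeasure.distance(
                            vectorWithNorms[dataPointId1], vectorWithNorms[dataPointId2]));
}
```
The single, complete, and average linkages could then call `cachedDistance` in their inner loops. Ward linkage computes distances between cluster means rather than between data points, so it would not benefit from this particular cache.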
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]