zhipeng93 commented on code in PR #191:
URL: https://github.com/apache/flink-ml/pull/191#discussion_r1064449393


##########
docs/content/docs/operators/feature/minhashlsh.md:
##########
@@ -0,0 +1,287 @@
+---
+title: "MinHash LSH"
+weight: 1
+type: docs
+aliases:
+- /operators/feature/minhashlsh.html
+---
+
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+## MinHash LSH
+
+MinHash LSH is a Locality Sensitive Hashing (LSH) scheme for Jaccard distance metric.
+The input features are sets of natural numbers represented as non-zero indices of vectors,
+either dense vectors or sparse vectors. Typically, sparse vectors are more efficient.
+
+In addition to transforming input feature vectors to multiple hash values, the MinHash LSH
+model also supports approximate nearest neighbors search within a dataset regarding a key
+vector and approximate similarity join between two datasets.
+
+### Input Columns
+
+| Param name | Type   | Default   | Description            |
+|:-----------|:-------|:----------|:-----------------------|
+| inputCol   | Vector | `"input"` | Features to be mapped. |
+
+### Output Columns
+
+| Param name | Type          | Default    | Description  |
+|:-----------|:--------------|:-----------|:-------------|
+| outputCol  | DenseVector[] | `"output"` | Hash values. |
+
+### Parameters
+
+Below are the parameters required by `MinHashLSHModel`.
+
+| Key                     | Default    | Type    | Required | Description                                                        |
+|-------------------------|------------|---------|----------|--------------------------------------------------------------------|
+| inputCol                | `"input"`  | String  | no       | Input column name.                                                 |
+| outputCol               | `"output"` | String  | no       | Output column name.                                                |
+
+`MinHashLSH` needs the parameters above as well as the parameters below.
+
+| Key                     | Default    | Type    | Required | Description                                                        |
+|-------------------------|------------|---------|----------|--------------------------------------------------------------------|
+| seed                    | `null`     | Long    | no       | The random seed.                                                   |
+| numHashTables           | `1`        | Integer | no       | Default number of hash tables, for OR-amplification.               |
+| numHashFunctionPerTable | `1`        | Integer | no       | Default number of hash functions per table, for AND-amplification. |
+
+### Examples
+
+{{< tabs examples >}}
+
+{{< tab "Java">}}
+
+```java
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.ml.feature.lsh.MinHashLSH;
+import org.apache.flink.ml.feature.lsh.MinHashLSHModel;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.apache.flink.table.api.Expressions.$;
+
+/**
+ * Simple program that trains a MinHashLSH model and uses it for approximate nearest neighbors and
+ * similarity join.
+ */
+public class MinHashLSHExample {
+    public static void main(String[] args) throws Exception {
+
+        // Creates a new StreamExecutionEnvironment

Review Comment:
   Java/Python doc comments usually end with a `.` since they are sentences. Could you update it?
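A minimal plain-Java sketch of the interpretation the doc describes: each vector is reduced to the set of its non-zero indices (the values themselves are ignored), and two such sets are compared by Jaccard distance, 1 - |intersection| / |union|. The class and method names are hypothetical and not part of the Flink ML API; the sample vectors are dense encodings of two index sets from the Python test fixture further down.

```java
import java.util.Set;
import java.util.TreeSet;

/**
 * Hypothetical sketch (not the Flink ML API): treats the non-zero indices of a vector as a set
 * and computes the Jaccard distance between two such sets.
 */
public class JaccardDistanceSketch {

    /** Collects the indices whose values are non-zero; the values themselves are ignored. */
    static Set<Integer> nonZeroIndices(double[] dense) {
        Set<Integer> indices = new TreeSet<>();
        for (int i = 0; i < dense.length; i++) {
            if (dense[i] != 0.0) {
                indices.add(i);
            }
        }
        return indices;
    }

    /** Jaccard distance 1 - |intersection| / |union|; both sets are assumed to be non-empty. */
    static double jaccardDistance(Set<Integer> a, Set<Integer> b) {
        Set<Integer> intersection = new TreeSet<>(a);
        intersection.retainAll(b);
        Set<Integer> union = new TreeSet<>(a);
        union.addAll(b);
        return 1.0 - (double) intersection.size() / union.size();
    }

    public static void main(String[] args) {
        // Dense encodings of the index sets {0, 1, 2} and {0, 2, 4}.
        double[] x = {1, 1, 1, 0, 0, 0};
        double[] y = {1, 0, 1, 0, 1, 0};
        double d = jaccardDistance(nonZeroIndices(x), nonZeroIndices(y));
        System.out.println(d); // 0.5: two shared indices out of four distinct ones
    }
}
```

Because only the index set matters, a sparse vector presumably already carries exactly the information needed, which fits the doc's remark that sparse inputs are typically more efficient.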



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/LSHParams.java:
##########
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.lsh;
+
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params for {@link LSH}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface LSHParams<T> extends LSHModelParams<T> {
+
+    /**
+     * Param for the number of hash tables used in LSH OR-amplification.
+     *
+     * <p>OR-amplification can be used to reduce the false negative rate. Higher values of this
+     * param lead to a reduced false negative rate, at the expense of added computational
+     * complexity.
+     */
+    Param<Integer> NUM_HASH_TABLES =
+            new IntParam("numHashTables", "Number of hash tables.", 1, ParamValidators.gtEq(1));
+
+    default int getNumHashTables() {

Review Comment:
   Let's declare the variables first and then the functions, following the existing code style (e.g., RegexTokenizerParams and StandardScalerParams).
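The `NUM_HASH_TABLES` Javadoc says more tables lower the false negative rate. Under the textbook LSH analysis (an assumption here, not something the PR states), one MinHash value collides with probability s for a pair with Jaccard similarity s. With r = `numHashFunctionPerTable` values that must all match within a table (AND) and b = `numHashTables` tables of which any one may match (OR), the pair becomes a candidate with probability 1 - (1 - s^r)^b. A small sketch with hypothetical names:

```java
/**
 * Hypothetical sketch: probability that a pair with Jaccard similarity s becomes a candidate,
 * assuming independent MinHash functions that each collide with probability s.
 */
public class AmplificationSketch {

    /**
     * @param s Jaccard similarity of the pair, in [0, 1].
     * @param numHashFunctionPerTable r: all r hash values must match within one table (AND).
     * @param numHashTables b: a match in any one of the b tables is enough (OR).
     */
    static double candidateProbability(double s, int numHashFunctionPerTable, int numHashTables) {
        double tableMatch = Math.pow(s, numHashFunctionPerTable); // AND within a table
        return 1.0 - Math.pow(1.0 - tableMatch, numHashTables); // OR across tables
    }

    public static void main(String[] args) {
        // More tables raise the chance that a moderately similar pair (s = 0.5) is found.
        for (int b : new int[] {1, 5, 20}) {
            System.out.printf("r=3, b=%d -> %.4f%n", b, candidateProbability(0.5, 3, b));
        }
    }
}
```

Under this model, raising `numHashTables` pushes the probability toward 1 for similar pairs (fewer false negatives), while each extra table adds another hash computation per row, which matches the trade-off described in the Javadoc.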



##########
flink-ml-python/pyflink/ml/lib/feature/tests/test_minhashlsh.py:
##########
@@ -0,0 +1,271 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import functools
+import os
+from typing import List
+
+from pyflink.common import Row, Types
+from pyflink.java_gateway import get_gateway
+from pyflink.ml.core.linalg import Vectors, SparseVectorTypeInfo, DenseVector
+from pyflink.ml.core.wrapper import JavaWithParams
+from pyflink.ml.lib.feature.lsh import MinHashLSH, MinHashLSHModel
+from pyflink.ml.tests.test_utils import PyFlinkMLTestCase
+from pyflink.table import Table
+
+
+class MinHashLSHTest(PyFlinkMLTestCase):
+    def setUp(self):
+        super(MinHashLSHTest, self).setUp()
+        self.data = self.t_env.from_data_stream(
+            self.env.from_collection([
+                (0, Vectors.sparse(6, [0, 1, 2], [1., 1., 1.])),
+                (1, Vectors.sparse(6, [2, 3, 4], [1., 1., 1.])),
+                (2, Vectors.sparse(6, [0, 2, 4], [1., 1., 1.])),
+            ],
+                type_info=Types.ROW_NAMED(
+                    ['id', 'vec'],
+                    [Types.INT(), SparseVectorTypeInfo()])))
+
+        self.expected = [
+            Row([
+                Vectors.dense(1.73046954E8, 1.57275425E8, 6.90717571E8),
+                Vectors.dense(5.02301169E8, 7.967141E8, 4.06089319E8),
+                Vectors.dense(2.83652171E8, 1.97714719E8, 6.04731316E8),
+                Vectors.dense(5.2181506E8, 6.36933726E8, 6.13894128E8),
+                Vectors.dense(3.04301769E8, 1.113672955E9, 6.1388711E8),
+            ]),
+            Row([
+                Vectors.dense(1.73046954E8, 1.57275425E8, 6.7798584E7),
+                Vectors.dense(6.38582806E8, 1.78703694E8, 4.06089319E8),
+                Vectors.dense(6.232638E8, 9.28867E7, 9.92010642E8),
+                Vectors.dense(2.461064E8, 1.12787481E8, 1.92180297E8),
+                Vectors.dense(2.38162496E8, 1.552933319E9, 2.77995137E8),
+            ]),
+            Row([
+                Vectors.dense(1.73046954E8, 1.57275425E8, 6.90717571E8),
+                Vectors.dense(1.453197722E9, 7.967141E8, 4.06089319E8),
+                Vectors.dense(6.232638E8, 1.97714719E8, 6.04731316E8),
+                Vectors.dense(2.461064E8, 1.12787481E8, 1.92180297E8),
+                Vectors.dense(1.224130231E9, 1.113672955E9, 2.77995137E8),
+            ])]
+
+    def test_param(self):
+        lsh = MinHashLSH()
+        self.assertEqual('input', lsh.get_input_col())

Review Comment:
   Let's use `lsh.input_col()` here; otherwise `lsh.input_col()` may not be tested.
   
   Same for the other Python functions. 



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/LSHModel.java:
##########
@@ -0,0 +1,458 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.lsh;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+
+/**
+ * Base class for LSH model.
+ *
+ * <p>In addition to transforming input feature vectors to multiple hash values, it also supports
+ * approximate nearest neighbors search within a dataset regarding a key vector and approximate
+ * similarity join between two datasets.
+ *
+ * @param <T> class type of the LSHModel implementation itself.
+ */
+abstract class LSHModel<T extends LSHModel<T>> implements Model<T>, LSHModelParams<T> {
+    private static final String MODEL_DATA_BC_KEY = "modelData";
+
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    /** Stores the corresponding model data class of T. */
+    private final Class<? extends LSHModelData> modelDataClass;
+
+    protected Table modelDataTable;
+
+    public LSHModel(Class<? extends LSHModelData> modelDataClass) {
+        this.modelDataClass = modelDataClass;
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public T setModelData(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        modelDataTable = inputs[0];
+        return (T) this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<? extends LSHModelData> modelData =
+                tEnv.toDataStream(modelDataTable, modelDataClass);
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        TypeInformation<?> outputType = TypeInformation.of(DenseVector[].class);
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), outputType),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol()));
+
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(tEnv.toDataStream(inputs[0])),
+                        Collections.singletonMap(MODEL_DATA_BC_KEY, modelData),
+                        inputList -> {
+                            //noinspection unchecked
+                            DataStream<Row> data = (DataStream<Row>) inputList.get(0);
+                            return data.map(new PredictFunction(getInputCol()), outputTypeInfo);
+                        });
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    /**
+     * Approximately finds at most k items from a dataset which have the closest distance to a given
+     * item. If the `outputCol` is missing in the given dataset, this method first transforms the
+     * dataset with the model.
+     *
+     * @param dataset The dataset in which to search for nearest neighbors.
+     * @param key The item to search for.
+     * @param k The maximum number of nearest neighbors.
+     * @param distCol The output column storing the distance between each neighbor and the key.
+     * @return A dataset containing at most k items closest to the key with a column named `distCol`
+     *     appended.
+     */
+    public Table approxNearestNeighbors(Table dataset, Vector key, int k, String distCol) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) dataset).getTableEnvironment();
+        Table transformedTable =
+                (dataset.getResolvedSchema().getColumnNames().contains(getOutputCol()))
+                        ? dataset
+                        : transform(dataset)[0];
+
+        DataStream<? extends LSHModelData> modelData =
+                tEnv.toDataStream(modelDataTable, modelDataClass);
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(transformedTable.getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.DOUBLE),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), distCol));
+
+        // Fetch items in the same bucket as the key, and calculate their distances to the key.
+        DataStream<Row> filteredData =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(tEnv.toDataStream(transformedTable)),
+                        Collections.singletonMap(MODEL_DATA_BC_KEY, modelData),
+                        inputList -> {
+                            //noinspection unchecked
+                            DataStream<Row> data = (DataStream<Row>) inputList.get(0);
+                            return data.flatMap(
+                                    new FilterByBucketFunction(getInputCol(), getOutputCol(), key),
+                                    outputTypeInfo);
+                        });
+        DataStream<List<Row>> topKList =
+                DataStreamUtils.aggregate(filteredData, new TopKFunction(distCol, k));
+        DataStream<Row> topKData =
+                topKList.flatMap(
+                        (value, out) -> {
+                            for (Row row : value) {
+                                out.collect(row);
+                            }
+                        });
+        topKData.getTransformation().setOutputType(outputTypeInfo);
+        return tEnv.fromDataStream(topKData);
+    }
+
+    /**
+     * An overloaded version of `approxNearestNeighbors` with "distCol" as default value of
+     * `distCol`.
+     */
+    public Table approxNearestNeighbors(Table dataset, Vector key, int k) {
+        return approxNearestNeighbors(dataset, key, k, "distCol");
+    }
+
+    /**
+     * Joins two datasets to approximately find all pairs of rows whose distances are smaller than or
+     * equal to the threshold. If the `outputCol` is missing in either dataset, this method first
+     * transforms that dataset.
+     *
+     * @param datasetA One dataset.
+     * @param datasetB The other dataset.
+     * @param threshold The distance threshold.
+     * @param idCol A column in the two datasets to identify each row.
+     * @param distCol The output column storing the distance between each pair of rows.
+     * @return A joined dataset containing pairs of rows. The original rows are in columns
+     *     "datasetA" and "datasetB", and a column "distCol" is added to show the distance between
+     *     each pair.
+     */
+    public Table approxSimilarityJoin(
+            Table datasetA, Table datasetB, double threshold, String idCol, String distCol) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) datasetA).getTableEnvironment();
+
+        DataStream<Row> explodedA = preprocessData(datasetA, idCol);
+        DataStream<Row> explodedB = preprocessData(datasetB, idCol);
+
+        DataStream<? extends LSHModelData> modelData =
+                tEnv.toDataStream(modelDataTable, modelDataClass);
+        DataStream<Row> sameBucketPairs =
+                explodedA
+                        .join(explodedB)
+                        .where(new IndexHashValueKeySelector())
+                        .equalTo(new IndexHashValueKeySelector())
+                        .window(EndOfStreamWindows.get())
+                        .apply(
+                                (r0, r1) ->
+                                        Row.of(
+                                                r0.getField(0),
+                                                r1.getField(0),
+                                                r0.getField(1),
+                                                r1.getField(1)));
+        DataStream<Row> distinctSameBucketPairs =
+                DataStreamUtils.reduce(
+                        sameBucketPairs.keyBy(
+                                new KeySelector<Row, Tuple2<Integer, Integer>>() {
+                                    @Override
+                                    public Tuple2<Integer, Integer> getKey(Row r) {
+                                        return Tuple2.of(r.getFieldAs(0), r.getFieldAs(1));
+                                    }
+                                }),
+                        (r0, r1) -> r0);
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(datasetA.getResolvedSchema());
+        TypeInformation<?> idColType = inputTypeInfo.getTypeAt(idCol);
+        DataStream<Row> pairsWithDists =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(distinctSameBucketPairs),
+                        Collections.singletonMap(MODEL_DATA_BC_KEY, modelData),
+                        inputList -> {
+                            DataStream<Row> data = (DataStream<Row>) inputList.get(0);
+                            return data.flatMap(
+                                    new FilterByDistanceFunction(threshold),
+                                    new RowTypeInfo(
+                                            new TypeInformation[] {
+                                                idColType, idColType, Types.DOUBLE
+                                            },
+                                            new String[] {"datasetA.id", "datasetB.id", distCol}));
+                        });
+        return tEnv.fromDataStream(pairsWithDists);
+    }
+
+    /**
+     * An overloaded version of `approxSimilarityJoin` with "distCol" as default value of
+     * `distCol`.
+     */
+    public Table approxSimilarityJoin(
+            Table datasetA, Table datasetB, double threshold, String idCol) {
+        return approxSimilarityJoin(datasetA, datasetB, threshold, idCol, "distCol");
+    }
+
+    private DataStream<Row> preprocessData(Table dataTable, String idCol) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) dataTable).getTableEnvironment();
+
+        dataTable =
+                (dataTable.getResolvedSchema().getColumnNames().contains(getOutputCol()))
+                        ? dataTable
+                        : transform(dataTable)[0];
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(dataTable.getResolvedSchema());
+        TypeInformation<?> idColType = inputTypeInfo.getTypeAt(idCol);
+        final String indexCol = "index";
+        final String hashValueCol = "hashValue";
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        new TypeInformation[] {
+                            idColType,
+                            TypeInformation.of(Vector.class),
+                            Types.INT,
+                            TypeInformation.of(DenseVector.class)
+                        },
+                        new String[] {idCol, getInputCol(), indexCol, hashValueCol});
+
+        return tEnv.toDataStream(dataTable)
+                .flatMap(
+                        new ExplodeHashValuesFunction(idCol, getInputCol(), getOutputCol()),
+                        outputTypeInfo);
+    }
+
+    private static class PredictFunction extends RichMapFunction<Row, Row> {
+        private final String inputCol;
+
+        private LSHModelData modelData;
+
+        public PredictFunction(String inputCol) {
+            this.inputCol = inputCol;
+        }
+
+        @Override
+        public Row map(Row value) throws Exception {
+            if (null == modelData) {
+                modelData =
+                        (LSHModelData)
+                                getRuntimeContext().getBroadcastVariable(MODEL_DATA_BC_KEY).get(0);
+            }
+            Vector[] hashValues = modelData.hashFunction(value.getFieldAs(inputCol));
+            return Row.join(value, Row.of((Object) hashValues));
+        }
+    }
+
+    private static class FilterByBucketFunction extends RichFlatMapFunction<Row, Row> {
+        private final String inputCol;
+        private final String outputCol;
+        private final Vector key;
+        private LSHModelData modelData;
+        private DenseVector[] keyHashes;
+
+        public FilterByBucketFunction(String inputCol, String outputCol, Vector key) {
+            this.inputCol = inputCol;
+            this.outputCol = outputCol;
+            this.key = key;
+        }
+
+        @Override
+        public void flatMap(Row value, Collector<Row> out) throws Exception {
+            if (null == modelData) {
+                modelData =
+                        (LSHModelData)
+                                getRuntimeContext().getBroadcastVariable(MODEL_DATA_BC_KEY).get(0);
+                keyHashes = modelData.hashFunction(key);
+            }
+            DenseVector[] hashes = value.getFieldAs(outputCol);
+            boolean sameBucket = false;
+            for (int i = 0; i < keyHashes.length; i += 1) {
+                if (keyHashes[i].equals(hashes[i])) {
+                    sameBucket = true;
+                    break;
+                }
+            }
+            if (!sameBucket) {
+                return;
+            }
+            Vector vec = value.getFieldAs(inputCol);
+            double dist = modelData.keyDistance(key, vec);
+            out.collect(Row.join(value, Row.of(dist)));
+        }
+    }
+
+    private static class TopKFunction
+            implements AggregateFunction<Row, PriorityQueue<Row>, List<Row>> {
+        private final int numNearestNeighbors;
+        private final String distCol;
+
+        private static class DistColComparator implements Comparator<Row> {
+
+            private final String distCol;
+
+            private DistColComparator(String distCol) {
+                this.distCol = distCol;
+            }
+
+            @Override
+            public int compare(Row o1, Row o2) {
+                return Double.compare(o1.getFieldAs(distCol), o2.getFieldAs(distCol));
+            }
+        }
+
+        public TopKFunction(String distCol, int numNearestNeighbors) {
+            this.distCol = distCol;
+            this.numNearestNeighbors = numNearestNeighbors;
+        }
+
+        @Override
+        public PriorityQueue<Row> createAccumulator() {
+            return new PriorityQueue<>(numNearestNeighbors, new DistColComparator(distCol));
+        }
+
+        @Override
+        public PriorityQueue<Row> add(Row value, PriorityQueue<Row> accumulator) {
+            if (accumulator.size() == numNearestNeighbors) {
+                Row peek = accumulator.peek();
+                if (accumulator.comparator().compare(value, peek) < 0) {
+                    accumulator.poll();
+                }
+            }
+            accumulator.add(value);
+            return accumulator;
+        }
+
+        @Override
+        public List<Row> getResult(PriorityQueue<Row> accumulator) {
+            return new ArrayList<>(accumulator);
+        }
+
+        @Override
+        public PriorityQueue<Row> merge(PriorityQueue<Row> a, PriorityQueue<Row> b) {
+            PriorityQueue<Row> merged = new PriorityQueue<>(a);
+            for (Row row : b) {
+                merged = add(row, merged);

Review Comment:
   Is this assignment necessary?
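On the assignment: `add` mutates and returns the same accumulator it receives, so `merged = add(row, merged)` rebinds `merged` to the object it already references, and a plain `add(row, merged)` should behave the same. It may also be worth double-checking `add` itself: as quoted, it calls `accumulator.add(value)` unconditionally, so the queue can grow past `numNearestNeighbors`, and with the ascending `DistColComparator`, `peek()` returns the closest row rather than the farthest. A common pattern for a bounded top-k is a max-heap whose head is the farthest retained item. This sketch is plain Java rather than Flink's `AggregateFunction`, and `Candidate` is a hypothetical stand-in for a `Row` with a distance column:

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;

/** Hypothetical sketch: keeps the k candidates with the smallest distance using a max-heap. */
public class BoundedTopKSketch {

    /** Hypothetical (id, distance) pair standing in for a Row with a distCol field. */
    static final class Candidate {
        final int id;
        final double dist;

        Candidate(int id, double dist) {
            this.id = id;
            this.dist = dist;
        }
    }

    private final int k;
    // Reversed order: the candidate with the LARGEST distance sits at the head.
    private final PriorityQueue<Candidate> heap;

    /** k must be at least 1 (PriorityQueue rejects an initial capacity below 1). */
    BoundedTopKSketch(int k) {
        this.k = k;
        this.heap =
                new PriorityQueue<>(k, Comparator.comparingDouble((Candidate c) -> c.dist).reversed());
    }

    /** Adds a candidate; the heap never holds more than k items. */
    void add(Candidate c) {
        if (heap.size() < k) {
            heap.add(c);
        } else if (c.dist < heap.peek().dist) {
            heap.poll(); // evict the farthest retained candidate
            heap.add(c);
        }
        // Otherwise c is no closer than any retained candidate and is dropped.
    }

    /** Merges another partial result in place; add() mutates this heap, so nothing is reassigned. */
    void merge(BoundedTopKSketch other) {
        for (Candidate c : other.heap) {
            add(c);
        }
    }

    /** Returns the retained candidates sorted by ascending distance. */
    List<Candidate> result() {
        List<Candidate> list = new ArrayList<>(heap);
        list.sort(Comparator.comparingDouble(c -> c.dist));
        return list;
    }

    public static void main(String[] args) {
        BoundedTopKSketch left = new BoundedTopKSketch(2);
        left.add(new Candidate(0, 0.8));
        left.add(new Candidate(1, 0.5));
        BoundedTopKSketch right = new BoundedTopKSketch(2);
        right.add(new Candidate(2, 0.25));
        right.add(new Candidate(3, 0.9));
        left.merge(right);
        for (Candidate c : left.result()) {
            System.out.println(c.id + " -> " + c.dist); // prints "2 -> 0.25", then "1 -> 0.5"
        }
    }
}
```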



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/lsh/MinHashLSH.java:
##########
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.lsh;
+
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+
+import java.io.IOException;
+
+/**
+ * An Estimator that implements the MinHash LSH algorithm, which supports LSH for Jaccard distance.
+ *
+ * <p>The input could be dense or sparse vectors. Each input vector must have at least one non-zero
+ * index and all non-zero values are treated as binary "1" values. The sizes of input vectors should
+ * be same and not too large (not larger than a large prime 2038074743).

Review Comment:
   nit: ... and not larger than a predefined prime (i.e., 2038074743).
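For context on why the vector size is bounded by the prime, here is a sketch of one common MinHash family, h(i) = ((1 + i) * a + b) mod p with p = 2038074743, taking the minimum over a vector's non-zero indices. The exact formula and coefficient ranges are assumptions modeled on Spark MLlib's MinHashLSH; the quoted diff only states the prime. For indices below p, the values 1 + i are distinct modulo p, so any (a, b) with a not divisible by p maps distinct indices to distinct hash values, whereas larger vectors would let indices i and i + p collide. The class name and seed are hypothetical.

```java
import java.util.Random;

/**
 * Hypothetical sketch of a MinHash signature. The family h(i) = ((1 + i) * a + b) mod HASH_PRIME
 * is an assumption modeled on Spark MLlib's MinHashLSH, not taken from this PR.
 */
public class MinHashSketch {
    static final long HASH_PRIME = 2038074743L;

    /** Minimum hash over the non-zero indices; (1 + i) * a < 2^62, so long math cannot overflow. */
    static long minHash(int[] nonZeroIndices, long a, long b) {
        long min = Long.MAX_VALUE;
        for (int i : nonZeroIndices) {
            long h = ((1L + i) * a + b) % HASH_PRIME;
            min = Math.min(min, h);
        }
        return min;
    }

    public static void main(String[] args) {
        Random rnd = new Random(2022L); // hypothetical fixed seed for reproducibility
        int numHashes = 200;
        int[] x = {0, 1, 2};
        int[] y = {0, 2, 4};
        int agree = 0;
        for (int j = 0; j < numHashes; j++) {
            long a = 1 + rnd.nextInt((int) (HASH_PRIME - 1)); // a in [1, HASH_PRIME - 1]
            long b = rnd.nextInt((int) HASH_PRIME); // b in [0, HASH_PRIME - 1]
            if (minHash(x, a, b) == minHash(y, a, b)) {
                agree++;
            }
        }
        // The agreement rate roughly estimates the Jaccard similarity of x and y (0.5 here);
        // linear hashes are only approximately min-wise independent, so expect some deviation.
        System.out.println((double) agree / numHashes);
    }
}
```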


