lindong28 commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r766351212



##########
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/DenseMatrix.java
##########
@@ -0,0 +1,74 @@
+package org.apache.flink.ml.linalg;
+
+import org.apache.flink.api.common.typeinfo.TypeInfo;
+import org.apache.flink.ml.linalg.typeinfo.DenseMatrixTypeInfoFactory;
+
+import org.apache.flink.shaded.curator4.com.google.common.base.Preconditions;
+
+/**
+ * Column-major dense matrix. The entry values are stored in a single array of doubles with columns
+ * listed in sequence.
+ */
+@TypeInfo(DenseMatrixTypeInfoFactory.class)
+public class DenseMatrix implements Matrix {
+
+    /** Row dimension. */
+    public final int numRows;

Review comment:
       I previously forgot that `numRows()` and `numCols()` are already part of this class's API.
   
   Since we already have `numRows()` and `numCols()`, could we make the `numRows` and `numCols` fields private?
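A minimal sketch of what the suggestion amounts to. It uses a hypothetical, stripped-down `DenseMatrixSketch` rather than the real `DenseMatrix` (which also implements `Matrix` and carries the `@TypeInfo` annotation), and it assumes nothing else, such as the type-info serializer, reads the dimension fields directly. The dimensions become private and callers go through `numRows()`/`numCols()`. The `get(i, j)` accessor is an illustrative assumption that spells out the column-major layout described in the class Javadoc.

```java
/**
 * Hypothetical, stripped-down column-major dense matrix: the dimensions are private
 * and exposed only through accessors, as suggested in the review.
 */
class DenseMatrixSketch {
    /** Row dimension, exposed only through numRows(). */
    private final int numRows;
    /** Column dimension, exposed only through numCols(). */
    private final int numCols;
    /** Entries stored column by column; public only to mirror DenseVector.values in BLAS. */
    public final double[] values;

    DenseMatrixSketch(int numRows, int numCols) {
        this(numRows, numCols, new double[numRows * numCols]);
    }

    DenseMatrixSketch(int numRows, int numCols, double[] values) {
        if (values.length != numRows * numCols) {
            throw new IllegalArgumentException("Values array length mismatched.");
        }
        this.numRows = numRows;
        this.numCols = numCols;
        this.values = values;
    }

    public int numRows() {
        return numRows;
    }

    public int numCols() {
        return numCols;
    }

    /** Hypothetical accessor: entry (i, j) lives at index j * numRows + i in column-major order. */
    public double get(int i, int j) {
        return values[numRows * j + i];
    }
}
```

Fields and methods live in separate namespaces in Java, so keeping the names `numRows`/`numRows()` side by side is legal; only the visibility changes.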

##########
File path: flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java
##########
@@ -22,13 +22,86 @@
 
 /** A utility class that provides BLAS routines over matrices and vectors. */
 public class BLAS {
+
     /** For level-1 function dspmv, use javaBLAS for better performance. */
     private static final dev.ludovic.netlib.BLAS JAVA_BLAS =
             dev.ludovic.netlib.JavaBLAS.getInstance();
 
-    /** y += a * x . */
+    /**
+     * \sum_i |x_i| .
+     *
+     * @param x x
+     * @return \sum_i |x_i|
+     */
+    public static double asum(DenseVector x) {
+        return JAVA_BLAS.dasum(x.size(), x.values, 0, 1);
+    }
+
+    /**
+     * y += a * x .
+     *
+     * @param a a
+     * @param x x
+     * @param y y
+     */
     public static void axpy(double a, DenseVector x, DenseVector y) {
         Preconditions.checkArgument(x.size() == y.size(), "Vector size mismatched.");
         JAVA_BLAS.daxpy(x.size(), a, x.values, 1, y.values, 1);
     }
+
+    /**
+     * x \cdot y .
+     *
+     * @param x x
+     * @param y y
+     * @return x \cdot y
+     */
+    public static double dot(DenseVector x, DenseVector y) {
+        Preconditions.checkArgument(x.size() == y.size(), "Vector size mismatched.");
+        return JAVA_BLAS.ddot(x.size(), x.values, 1, y.values, 1);
+    }
+
+    /**
+     * \sqrt(\sum_i x_i * x_i) .
+     *
+     * @param x x
+     * @return \sqrt(\sum_i x_i * x_i)
+     */
+    public static double norm2(DenseVector x) {
+        return JAVA_BLAS.dnrm2(x.size(), x.values, 1);
+    }
+
+    /**
+     * x = x * a .
+     *
+     * @param a a
+     * @param x x
+     */
+    public static void scal(double a, DenseVector x) {
+        JAVA_BLAS.dscal(x.size(), a, x.values, 1);
+    }
+
+    /** */
+
+    /**
+     * y := alpha * A * x + beta * y.

Review comment:
       It seems that the Javadoc is inconsistent with the function signature. Could you help improve it?
   
   Also, could you remove the redundant empty `/** */` above it?
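The actual `gemv` signature is cut off in this excerpt, so the following is only a hedged illustration of a Javadoc that names and sizes every parameter of a gemv-style routine. `GemvSketch` and its parameter list are hypothetical, not the PR's signature; it is a plain-Java reference implementation over a column-major array that could also serve as a test oracle for the netlib-backed version.

```java
/** Hypothetical plain-Java reference for y := alpha * A * x + beta * y. */
final class GemvSketch {
    private GemvSketch() {}

    /**
     * y := alpha * A * x + beta * y, where A is an m-by-n matrix stored in column-major order.
     *
     * @param alpha scalar multiplier for A * x
     * @param m number of rows of A (and length of y)
     * @param n number of columns of A (and length of x)
     * @param a entries of A in column-major order, length m * n
     * @param x input vector of length n
     * @param beta scalar multiplier for the incoming y
     * @param y input/output vector of length m, overwritten with the result
     */
    static void gemv(
            double alpha, int m, int n, double[] a, double[] x, double beta, double[] y) {
        if (a.length != m * n || x.length != n || y.length != m) {
            throw new IllegalArgumentException("Matrix and vector sizes mismatched.");
        }
        for (int i = 0; i < m; i++) {
            // Following the reference BLAS convention, beta == 0 overwrites y instead of
            // scaling it, so NaN or garbage in the incoming y does not leak into the result.
            y[i] = (beta == 0.0) ? 0.0 : beta * y[i];
        }
        // Accumulate alpha * A * x one column at a time, walking the column-major array
        // contiguously.
        for (int j = 0; j < n; j++) {
            double t = alpha * x[j];
            for (int i = 0; i < m; i++) {
                y[i] += t * a[j * m + i];
            }
        }
    }
}
```

For example, with A = [[1, 2], [3, 4]] stored as `{1, 3, 2, 4}`, x = (1, 1), y = (10, 20), alpha = 2 and beta = 0.5, the call leaves y = (11, 24).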

##########
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##########
@@ -0,0 +1,287 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+
+/** Tests Knn and KnnModel. */
+public class KnnTest {
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainData;
+    private static final List<Row> trainArray =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(1.0, Vectors.dense(2.0, 3.0)),

Review comment:
       nits: the order of fields in LogisticRegressionTest is `features, label, weight`. Could we make the order consistent and also use `features, label` here?
   
   If we agree to enforce consistent ordering, could we also update `testArray` below as well as the output of `KnnModel`?

##########
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##########
@@ -0,0 +1,287 @@
+    private static final List<Row> trainArray =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(1.0, Vectors.dense(2.0, 3.0)),
+                            Row.of(1.0, Vectors.dense(2.1, 3.1)),
+                            Row.of(2.0, Vectors.dense(200.1, 300.1)),
+                            Row.of(2.0, Vectors.dense(200.2, 300.2)),
+                            Row.of(2.0, Vectors.dense(200.3, 300.3)),
+                            Row.of(2.0, Vectors.dense(200.4, 300.4)),
+                            Row.of(2.0, Vectors.dense(200.4, 300.4)),
+                            Row.of(2.0, Vectors.dense(200.6, 300.6)),
+                            Row.of(1.0, Vectors.dense(2.1, 3.1)),
+                            Row.of(1.0, Vectors.dense(2.1, 3.1)),
+                            Row.of(1.0, Vectors.dense(2.1, 3.1)),
+                            Row.of(1.0, Vectors.dense(2.1, 3.1)),
+                            Row.of(1.0, Vectors.dense(2.3, 3.2)),
+                            Row.of(1.0, Vectors.dense(2.3, 3.2)),
+                            Row.of(3.0, Vectors.dense(2.8, 3.2)),
+                            Row.of(4.0, Vectors.dense(300., 3.2)),
+                            Row.of(1.0, Vectors.dense(2.2, 3.2)),
+                            Row.of(5.0, Vectors.dense(2.4, 3.2)),
+                            Row.of(5.0, Vectors.dense(2.5, 3.2)),
+                            Row.of(5.0, Vectors.dense(2.5, 3.2)),
+                            Row.of(1.0, Vectors.dense(2.1, 3.1))));
+
+    private static final List<Row> testArray =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(5.0, Vectors.dense(4.0, 4.1)),
+                            Row.of(2.0, Vectors.dense(300, 42))));
+    private Table testData;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        Schema schema =
+                Schema.newBuilder()
+                        .column("f0", DataTypes.DOUBLE())
+                        .column("f1", DataTypes.of(DenseVector.class))
+                        .build();
+
+        DataStream<Row> dataStream = env.fromCollection(trainArray);
+        trainData = tEnv.fromDataStream(dataStream, schema).as("label", "features");
+
+        DataStream<Row> predDataStream = env.fromCollection(testArray);
+        testData = tEnv.fromDataStream(predDataStream, schema).as("label", "features");
+    }
+
+    // Executes the graph and returns a list which has true label and predict label.
+    private static List<Tuple2<Double, Double>> executeAndCollect(Table output) throws Exception {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
+
+        DataStream<Tuple2<Double, Double>> stream =
+                tEnv.toDataStream(output)
+                        .map(
+                                new MapFunction<Row, Tuple2<Double, Double>>() {
+                                    @Override
+                                    public Tuple2<Double, Double> map(Row row) {
+                                        return Tuple2.of(
+                                                (Double) row.getField("label"),
+                                                (Double) row.getField("prediction"));
+                                    }
+                                });
+        return IteratorUtils.toList(stream.executeAndCollect());
+    }
+
+    private static void verifyClusteringResult(List<Tuple2<Double, Double>> result) {
+        for (Tuple2<Double, Double> t2 : result) {
+            Assert.assertEquals(t2.f0, t2.f1);
+        }
+    }
+
+    /** Tests Param. */
+    @Test
+    public void testParam() {
+        Knn knn = new Knn();
+        assertEquals("features", knn.getFeaturesCol());
+        assertEquals("label", knn.getLabelCol());
+        assertEquals(10L, knn.getK().longValue());
+        assertEquals("prediction", knn.getPredictionCol());
+
+        knn.setLabelCol("test_label")
+                .setFeaturesCol("test_features")
+                .setK(4)
+                .setPredictionCol("test_prediction");
+
+        assertEquals("test_features", knn.getFeaturesCol());
+        assertEquals("test_label", knn.getLabelCol());
+        assertEquals(4L, knn.getK().longValue());
+        assertEquals("test_prediction", knn.getPredictionCol());
+    }
+
+    @Test
+    public void testFeaturePredictionParam() throws Exception {
+        Knn knn =
+                new Knn()
+                        .setLabelCol("label")
+                        .setFeaturesCol("features")
+                        .setK(4)
+                        .setPredictionCol("prediction");
+        KnnModel model = knn.fit(trainData);
+        Table output = model.transform(testData)[0];
+
+        assertEquals(
+                Arrays.asList("label", "features", "prediction"),
+                output.getResolvedSchema().getColumnNames());
+
+        List<Tuple2<Double, Double>> result = executeAndCollect(output);
+        verifyClusteringResult(result);
+    }
+
+    @Test
+    public void testFewerDistinctPointsThanCluster() throws Exception {
+        Knn knn =
+                new Knn()
+                        .setLabelCol("label")
+                        .setFeaturesCol("features")
+                        .setK(4)
+                        .setPredictionCol("prediction");
+        KnnModel model = knn.fit(testData);
+        Table output = model.transform(testData)[0];
+
+        assertEquals(
+                Arrays.asList("label", "features", "prediction"),
+                output.getResolvedSchema().getColumnNames());
+        executeAndCollect(output);
+    }
+
+    @Test
+    public void testFitAndPredict() throws Exception {
+        Knn knn =
+                new Knn()
+                        .setLabelCol("label")
+                        .setFeaturesCol("features")
+                        .setK(4)
+                        .setPredictionCol("prediction");
+        KnnModel knnModel = knn.fit(trainData);
+        Table output = knnModel.transform(testData)[0];
+        List<Tuple2<Double, Double>> result = executeAndCollect(output);
+        verifyClusteringResult(result);
+    }
+
+    @Test
+    public void testSaveLoadAndPredict() throws Exception {
+        String path = Files.createTempDirectory("").toString();
+        Knn knn =
+                new Knn()
+                        .setLabelCol("label")
+                        .setFeaturesCol("features")
+                        .setK(4)
+                        .setPredictionCol("prediction");
+        knn.save(path);
+
+        Knn loadKnn = Knn.load(env, path);

Review comment:
       nits: could we rename `loadKnn` to `loadedKnn`?

##########
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##########
@@ -0,0 +1,287 @@
+    @Test
+    public void testFeaturePredictionParam() throws Exception {
+        Knn knn =
+                new Knn()
+                        .setLabelCol("label")

Review comment:
       Could we remove the `setLabelCol(...)`, `setFeaturesCol(...)` and `setPredictionCol(...)` calls for simplicity, given that the functionality of these setXXX(...) methods is already tested by `testParam()` and `testFeaturePredictionParam()`?
   
   The same applies to the other tests in this class.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,317 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.classification.knn.KnnModelData.ModelDataDecoder;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+
+/** Knn model fitted by estimator. */
+public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> {
+    protected Map<Param<?>, Object> params = new HashMap<>();
+    private Table modelDataTable;
+
+    public KnnModel() {
+        ParamUtils.initializeMapWithDefaultValues(params, this);
+    }
+
+    @Override
+    public KnnModel setModelData(Table... modelData) {
+        this.modelDataTable = modelData[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        StreamTableEnvironment tEnv =

Review comment:
       Could we use `Preconditions.checkArgument` to verify the number of inputs, similar to `LogisticRegressionModel::transform(...)`?
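As written in the diff, `transform(...)` reads `inputs[0]` without checking `inputs.length`, so a call with no tables fails with an `ArrayIndexOutOfBoundsException` and any extra tables are silently ignored. Below is a self-contained sketch of the suggested guard. `checkArgument` is a local stand-in with the usual `Preconditions.checkArgument` semantics (throw `IllegalArgumentException` when the condition is false), and `Object` stands in for `Table` so the snippet runs without Flink; both are assumptions for illustration only.

```java
/** Hypothetical sketch of validating transform() inputs before touching inputs[0]. */
final class TransformInputCheckSketch {
    private TransformInputCheckSketch() {}

    /** Local stand-in for Preconditions.checkArgument: throws when the condition is false. */
    static void checkArgument(boolean condition, String message) {
        if (!condition) {
            throw new IllegalArgumentException(message);
        }
    }

    /** Same shape as transform(Table... inputs), with Object standing in for Table. */
    static Object[] transform(Object... inputs) {
        // Fail fast with a clear message instead of an ArrayIndexOutOfBoundsException
        // (no inputs) or silently dropping extra inputs.
        checkArgument(
                inputs.length == 1,
                "Expected exactly 1 input table, but got " + inputs.length + ".");
        // ... the real method would convert inputs[0] to a DataStream here ...
        return new Object[] {inputs[0]};
    }
}
```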

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,317 @@
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Row> input = tEnv.toDataStream(inputs[0]);
+        DataStream<KnnModelData> model = KnnModelData.getModelDataStream(modelDataTable);

Review comment:
       nits: could we rename this variable to `modelData` to avoid confusing it with `knnModel`?
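For context on what `transform(...)` computes once the model data is broadcast, a brute-force kNN classification step can be sketched in plain Java. This is not the PR's implementation (the excerpt is cut off before the prediction logic); `KnnSketch` is hypothetical, and its choices (squared Euclidean distance, a bounded max-heap of size k, majority vote with ties going to the smallest label) are assumptions made for illustration.

```java
import java.util.Comparator;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.TreeMap;

/** Hypothetical brute-force kNN classifier: bounded max-heap of k nearest points, then a vote. */
final class KnnSketch {
    private final double[][] features;
    private final double[] labels;
    private final int k;

    KnnSketch(double[][] features, double[] labels, int k) {
        if (features.length != labels.length || k <= 0) {
            throw new IllegalArgumentException("Invalid training data or k.");
        }
        this.features = features;
        this.labels = labels;
        this.k = k;
    }

    private static double squaredDistance(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double d = a[i] - b[i];
            sum += d * d;
        }
        return sum;
    }

    double predict(double[] query) {
        // Max-heap on distance: the head is the farthest of the current k candidates.
        PriorityQueue<double[]> heap =
                new PriorityQueue<>(Comparator.comparingDouble((double[] e) -> e[0]).reversed());
        for (int i = 0; i < features.length; i++) {
            double dist = squaredDistance(query, features[i]);
            if (heap.size() < k) {
                heap.add(new double[] {dist, labels[i]});
            } else if (dist < heap.peek()[0]) {
                heap.poll();
                heap.add(new double[] {dist, labels[i]});
            }
        }
        // Majority vote; iterating a TreeMap makes ties resolve to the smallest label.
        Map<Double, Integer> votes = new TreeMap<>();
        for (double[] e : heap) {
            votes.merge(e[1], 1, Integer::sum);
        }
        double best = Double.NaN;
        int bestCount = 0;
        for (Map.Entry<Double, Integer> v : votes.entrySet()) {
            if (v.getValue() > bestCount) {
                best = v.getKey();
                bestCount = v.getValue();
            }
        }
        return best;
    }
}
```

Fed the 21 `trainArray` rows from `KnnTest` with k = 4, this sketch predicts 5.0 for (4.0, 4.1) and 2.0 for (300, 42), matching the labels in `testArray`.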

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,317 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.classification.knn.KnnModelData.ModelDataDecoder;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+
+/** Knn model fitted by estimator. */
+public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> {
+    protected Map<Param<?>, Object> params = new HashMap<>();
+    private Table modelDataTable;
+
+    public KnnModel() {
+        ParamUtils.initializeMapWithDefaultValues(params, this);
+    }
+
+    @Override
+    public KnnModel setModelData(Table... modelData) {
+        this.modelDataTable = modelData[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Row> input = tEnv.toDataStream(inputs[0]);
+        DataStream<KnnModelData> model = KnnModelData.getModelDataStream(modelDataTable);
+        final String broadcastKey = "broadcastModelKey";
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(), BasicTypeInfo.DOUBLE_TYPE_INFO),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol()));
+
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(input),
+                        Collections.singletonMap(broadcastKey, model),
+                        inputList -> {
+                            DataStream inoutData = inputList.get(0);

Review comment:
       nit: It is not clear why this variable is named `inoutData`. Could it be renamed to, e.g., `input` for consistency with the `input` and `inputList` used in this function?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,317 @@
+                            return inoutData.map(
+                                    new PredictLabelFunction(
+                                            broadcastKey, getK(), getFeaturesCol()),
+                                    outputTypeInfo);
+                        });
+
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return this.params;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                KnnModelData.getModelDataStream(modelDataTable),
+                path,
+                new KnnModelData.ModelDataEncoder(),
+                0);
+    }
+
+    /**
+     * Loads model data from path.
+     *
+     * @param env Stream execution environment.
+     * @param path Model path.
+     * @return Knn model.
+     */
+    public static KnnModel load(StreamExecutionEnvironment env, String path) throws IOException {
+        KnnModel model = ReadWriteUtils.loadStageParam(path);
+        Table modelDataTable = ReadWriteUtils.loadModelData(env, path, new ModelDataDecoder(), 0);
+        return model.setModelData(modelDataTable);
+    }
+
+    /** This operator loads model data and predicts result. */
+    private static class PredictLabelFunction extends RichMapFunction<Row, Row> {
+        private boolean firstElement = true;
+        private final String featureCol;
+        private transient KnnModelData knnModelData;
+        private final int topN;
+        private final String broadcastKey;
+        private transient Comparator<? super Tuple2<Double, Double>> comparator;
+
+        public PredictLabelFunction(String broadcastKey, int k, String featureCol) {
+            this.topN = k;

Review comment:
       nit: Could we just use `k` here, for consistency with the parameter `k`?
   
   It is probably better to use the same name for the same thing, for better code readability.
   
   Also, could we make the field initialization order consistent with the order of the constructor parameters?
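For illustration, a minimal sketch of the constructor shape this suggests. `PredictLabelFunctionSketch` is a hypothetical, simplified stand-in, not the class in the PR: it does not extend Flink's `RichMapFunction` and omits model loading, so only the field naming (`k` instead of `topN`) and the declaration/assignment order are meant to carry over.

```java
/**
 * Hypothetical, simplified stand-in for PredictLabelFunction (not the class in the PR).
 * Flink-specific parts (RichMapFunction, broadcast model loading) are omitted.
 */
class PredictLabelFunctionSketch {
    // Fields are declared and assigned in the same order as the constructor parameters.
    private final String broadcastKey;
    private final int k;
    private final String featureCol;

    PredictLabelFunctionSketch(String broadcastKey, int k, String featureCol) {
        this.broadcastKey = broadcastKey;
        this.k = k; // same name as the parameter, instead of a separate `topN`
        this.featureCol = featureCol;
    }

    String getBroadcastKey() {
        return broadcastKey;
    }

    int getK() {
        return k;
    }

    String getFeatureCol() {
        return featureCol;
    }
}
```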

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,317 @@
+            this.broadcastKey = broadcastKey;
+            this.featureCol = featureCol;
+        }
+
+        @Override
+        public Row map(Row row) {
+            if (firstElement) {
+                comparator = Comparator.comparingDouble(o -> -o.f0);
+                loadModel(getRuntimeContext().getBroadcastVariable(broadcastKey));
+                firstElement = false;
+            }
+            DenseVector vector = (DenseVector) row.getField(featureCol);
+            Tuple2<List<Double>, List<Double>> t2 = findNeighbor(vector, topN);
+            return Row.join(row, Row.of(getResult(t2)));
+        }
+
+        /**
+         * Finds the nearest topN neighbors from whole nodes.
+         *
+         * @param input Input vector.
+         * @param topN Top N.
+         * @return Neighbors.
+         */
+        private Tuple2<List<Double>, List<Double>> findNeighbor(DenseVector input, Integer topN) {
+            PriorityQueue<Tuple2<Double, Double>> priorityQueue = new PriorityQueue<>(comparator);
+            search(input, topN, priorityQueue);
+            List<Double> items = new ArrayList<>();
+            List<Double> metrics = new ArrayList<>();
+            while (!priorityQueue.isEmpty()) {
+                Tuple2<Double, Double> result = priorityQueue.poll();
+                items.add(result.f1);
+                metrics.add(result.f0);
+            }
+            Collections.reverse(items);
+            Collections.reverse(metrics);
+            priorityQueue.clear();

Review comment:
       The queue is guaranteed to be empty at this point, right? Do we still need to explicitly clear it?
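A standalone check of this point, using only `java.util`. Assumptions: a `double[]` holding `{distance, label}` stands in for Flink's `Tuple2<Double, Double>`, and `DrainDemo` is a hypothetical name. The `while (!queue.isEmpty())` loop exits only after `poll()` has removed every element, so a subsequent `clear()` has nothing left to remove.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;

/** Hypothetical demo; each element is {distance, label}, standing in for Tuple2<Double, Double>. */
class DrainDemo {
    /** Max-heap on distance, i.e. the same ordering as comparingDouble(o -> -o.f0). */
    static PriorityQueue<double[]> maxHeapByDistance() {
        return new PriorityQueue<>(Comparator.comparingDouble((double[] o) -> -o[0]));
    }

    /** Drains the queue and returns the distances nearest-first. */
    static List<Double> drainAscending(PriorityQueue<double[]> queue) {
        List<Double> distances = new ArrayList<>();
        // poll() removes the head on every iteration, so the loop exits only when the queue is empty.
        while (!queue.isEmpty()) {
            distances.add(queue.poll()[0]);
        }
        // The max-heap yields the farthest element first; reverse to get nearest-first order.
        Collections.reverse(distances);
        // At this point queue.isEmpty() is true, so queue.clear() would be a no-op.
        return distances;
    }
}
```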

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,317 @@
+            return Tuple2.of(items, metrics);
+        }
+
+        /**
+         * @param input Input vector.
+         * @param topN Top N.
+         * @param priorityQueue Priority queue.
+         */
+        private void search(
+                DenseVector input, int topN, PriorityQueue<Tuple2<Double, Double>> priorityQueue) {
+            double d = 0.0;
+            for (int i = 0; i < input.size(); ++i) {
+                d += input.values[i] * input.values[i];
+            }
+            Tuple2<DenseVector, Double> sample = Tuple2.of(input, d);
+            Tuple2<Double, Double> head = null;
+
+            List<Tuple2<Double, Double>> values = computeDistance(sample);
+            for (Tuple2<Double, Double> currentValue : values) {
+                head = updateQueue(priorityQueue, topN, currentValue, head);
+            }
+        }
+
+        /**
+         * Updates queue.
+         *
+         * @param pq Queue.
+         * @param topN Top N.
+         * @param newValue New value.
+         * @param head Head value.
+         * @return Head value.
+         */
+        private <T> Tuple2<Double, T> updateQueue(
+                PriorityQueue<Tuple2<Double, T>> pq,
+                int topN,
+                Tuple2<Double, T> newValue,
+                Tuple2<Double, T> head) {
+            if (pq.size() < topN) {
+                pq.add(Tuple2.of(newValue.f0, newValue.f1));
+                head = pq.peek();
+            } else {
+                if (pq.comparator().compare(head, newValue) < 0) {
+                    Tuple2<Double, T> peek = pq.poll();
+                    assert peek != null;

Review comment:
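       Could we avoid using `assert` in production code? Java assertions are disabled by default at runtime (they are only evaluated when the JVM is started with `-ea`), so this check would not run in production.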
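A minimal sketch of one way to replace the `assert`, under stated assumptions: `BoundedTopK` is a hypothetical helper, `double[]` pairs of `{distance, label}` stand in for Flink's `Tuple2`, and the separately tracked `head` is replaced by `peek()` (a design choice of this sketch, not what the PR does). Invariants are checked with explicit exceptions, which run whether or not `-ea` is set; in Flink code, `org.apache.flink.util.Preconditions.checkState` would be a natural equivalent.

```java
import java.util.Comparator;
import java.util.PriorityQueue;

/** Hypothetical helper; each element is {distance, label}, standing in for Tuple2<Double, Double>. */
class BoundedTopK {
    /** Max-heap on distance: the farthest retained neighbor sits at the head. */
    static PriorityQueue<double[]> newQueue() {
        return new PriorityQueue<>(Comparator.comparingDouble((double[] o) -> -o[0]));
    }

    /** Offers a candidate, keeping at most k nearest elements in the queue. */
    static void offer(PriorityQueue<double[]> pq, int k, double[] candidate) {
        if (k <= 0) {
            throw new IllegalArgumentException("k must be positive, got " + k);
        }
        if (pq.size() < k) {
            pq.add(candidate);
            return;
        }
        double[] farthest = pq.peek();
        if (farthest == null) {
            // Explicit check in place of `assert`: it runs even without -ea.
            throw new IllegalStateException("Queue is unexpectedly empty.");
        }
        if (candidate[0] < farthest[0]) {
            pq.poll(); // evict the current farthest neighbor
            pq.add(candidate);
        }
    }
}
```

Because the head of the max-heap is always the farthest retained neighbor, each new candidate needs only one comparison against `peek()`, and the result of `poll()` no longer needs a non-null assertion.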

##########
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##########
@@ -0,0 +1,287 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+
+/** Tests Knn and KnnModel. */
+public class KnnTest {
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainData;
+    private static final List<Row> trainArray =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(1.0, Vectors.dense(2.0, 3.0)),
+                            Row.of(1.0, Vectors.dense(2.1, 3.1)),
+                            Row.of(2.0, Vectors.dense(200.1, 300.1)),
+                            Row.of(2.0, Vectors.dense(200.2, 300.2)),
+                            Row.of(2.0, Vectors.dense(200.3, 300.3)),
+                            Row.of(2.0, Vectors.dense(200.4, 300.4)),
+                            Row.of(2.0, Vectors.dense(200.4, 300.4)),
+                            Row.of(2.0, Vectors.dense(200.6, 300.6)),
+                            Row.of(1.0, Vectors.dense(2.1, 3.1)),
+                            Row.of(1.0, Vectors.dense(2.1, 3.1)),
+                            Row.of(1.0, Vectors.dense(2.1, 3.1)),
+                            Row.of(1.0, Vectors.dense(2.1, 3.1)),
+                            Row.of(1.0, Vectors.dense(2.3, 3.2)),
+                            Row.of(1.0, Vectors.dense(2.3, 3.2)),
+                            Row.of(3.0, Vectors.dense(2.8, 3.2)),
+                            Row.of(4.0, Vectors.dense(300., 3.2)),
+                            Row.of(1.0, Vectors.dense(2.2, 3.2)),
+                            Row.of(5.0, Vectors.dense(2.4, 3.2)),
+                            Row.of(5.0, Vectors.dense(2.5, 3.2)),
+                            Row.of(5.0, Vectors.dense(2.5, 3.2)),
+                            Row.of(1.0, Vectors.dense(2.1, 3.1))));
+
+    private static final List<Row> testArray =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(5.0, Vectors.dense(4.0, 4.1)),
+                            Row.of(2.0, Vectors.dense(300, 42))));
+    private Table testData;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        Schema schema =
+                Schema.newBuilder()
+                        .column("f0", DataTypes.DOUBLE())
+                        .column("f1", DataTypes.of(DenseVector.class))
+                        .build();
+
+        DataStream<Row> dataStream = env.fromCollection(trainArray);
+        trainData = tEnv.fromDataStream(dataStream, schema).as("label", "features");
+
+        DataStream<Row> predDataStream = env.fromCollection(testArray);
+        testData = tEnv.fromDataStream(predDataStream, schema).as("label", "features");
+    }
+
+    // Executes the graph and returns a list which has true label and predict label.
+    private static List<Tuple2<Double, Double>> executeAndCollect(Table output) throws Exception {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
+
+        DataStream<Tuple2<Double, Double>> stream =
+                tEnv.toDataStream(output)
+                        .map(
+                                new MapFunction<Row, Tuple2<Double, Double>>() {
+                                    @Override
+                                    public Tuple2<Double, Double> map(Row row) {
+                                        return Tuple2.of(
+                                                (Double) row.getField("label"),
+                                                (Double) row.getField("prediction"));
+                                    }
+                                });
+        return IteratorUtils.toList(stream.executeAndCollect());
+    }
+
+    private static void verifyClusteringResult(List<Tuple2<Double, Double>> result) {
+        for (Tuple2<Double, Double> t2 : result) {
+            Assert.assertEquals(t2.f0, t2.f1);
+        }
+    }
+
+    /** Tests Param. */
+    @Test
+    public void testParam() {
+        Knn knn = new Knn();
+        assertEquals("features", knn.getFeaturesCol());
+        assertEquals("label", knn.getLabelCol());
+        assertEquals(10L, knn.getK().longValue());
+        assertEquals("prediction", knn.getPredictionCol());
+
+        knn.setLabelCol("test_label")
+                .setFeaturesCol("test_features")
+                .setK(4)
+                .setPredictionCol("test_prediction");
+
+        assertEquals("test_features", knn.getFeaturesCol());
+        assertEquals("test_label", knn.getLabelCol());
+        assertEquals(4L, knn.getK().longValue());
+        assertEquals("test_prediction", knn.getPredictionCol());
+    }
+
+    @Test
+    public void testFeaturePredictionParam() throws Exception {
+        Knn knn =
+                new Knn()
+                        .setLabelCol("label")
+                        .setFeaturesCol("features")
+                        .setK(4)
+                        .setPredictionCol("prediction");
+        KnnModel model = knn.fit(trainData);
+        Table output = model.transform(testData)[0];
+
+        assertEquals(
+                Arrays.asList("label", "features", "prediction"),
+                output.getResolvedSchema().getColumnNames());
+
+        List<Tuple2<Double, Double>> result = executeAndCollect(output);
+        verifyClusteringResult(result);
+    }
+
+    @Test
+    public void testFewerDistinctPointsThanCluster() throws Exception {
+        Knn knn =
+                new Knn()
+                        .setLabelCol("label")
+                        .setFeaturesCol("features")
+                        .setK(4)
+                        .setPredictionCol("prediction");
+        KnnModel model = knn.fit(testData);
+        Table output = model.transform(testData)[0];
+
+        assertEquals(
+                Arrays.asList("label", "features", "prediction"),
+                output.getResolvedSchema().getColumnNames());

Review comment:
       This already seems to be covered by `testFeaturePredictionParam()`. Could we remove this test for simplicity?

##########
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##########
@@ -0,0 +1,287 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+
+/** Tests Knn and KnnModel. */
+public class KnnTest {
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainData;
+    private static final List<Row> trainArray =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(1.0, Vectors.dense(2.0, 3.0)),
+                            Row.of(1.0, Vectors.dense(2.1, 3.1)),
+                            Row.of(2.0, Vectors.dense(200.1, 300.1)),
+                            Row.of(2.0, Vectors.dense(200.2, 300.2)),
+                            Row.of(2.0, Vectors.dense(200.3, 300.3)),
+                            Row.of(2.0, Vectors.dense(200.4, 300.4)),
+                            Row.of(2.0, Vectors.dense(200.4, 300.4)),
+                            Row.of(2.0, Vectors.dense(200.6, 300.6)),
+                            Row.of(1.0, Vectors.dense(2.1, 3.1)),
+                            Row.of(1.0, Vectors.dense(2.1, 3.1)),
+                            Row.of(1.0, Vectors.dense(2.1, 3.1)),
+                            Row.of(1.0, Vectors.dense(2.1, 3.1)),
+                            Row.of(1.0, Vectors.dense(2.3, 3.2)),
+                            Row.of(1.0, Vectors.dense(2.3, 3.2)),
+                            Row.of(3.0, Vectors.dense(2.8, 3.2)),
+                            Row.of(4.0, Vectors.dense(300., 3.2)),
+                            Row.of(1.0, Vectors.dense(2.2, 3.2)),
+                            Row.of(5.0, Vectors.dense(2.4, 3.2)),
+                            Row.of(5.0, Vectors.dense(2.5, 3.2)),
+                            Row.of(5.0, Vectors.dense(2.5, 3.2)),
+                            Row.of(1.0, Vectors.dense(2.1, 3.1))));
+
+    private static final List<Row> testArray =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(5.0, Vectors.dense(4.0, 4.1)),
+                            Row.of(2.0, Vectors.dense(300, 42))));
+    private Table testData;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        Schema schema =
+                Schema.newBuilder()
+                        .column("f0", DataTypes.DOUBLE())
+                        .column("f1", DataTypes.of(DenseVector.class))
+                        .build();
+
+        DataStream<Row> dataStream = env.fromCollection(trainArray);
+        trainData = tEnv.fromDataStream(dataStream, schema).as("label", "features");
+
+        DataStream<Row> predDataStream = env.fromCollection(testArray);
+        testData = tEnv.fromDataStream(predDataStream, schema).as("label", "features");
+    }
+
+    // Executes the graph and returns a list which has the true label and the predicted label.
+    private static List<Tuple2<Double, Double>> executeAndCollect(Table output) throws Exception {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
+
+        DataStream<Tuple2<Double, Double>> stream =
+                tEnv.toDataStream(output)
+                        .map(
+                                new MapFunction<Row, Tuple2<Double, Double>>() {
+                                    @Override
+                                    public Tuple2<Double, Double> map(Row row) {
+                                        return Tuple2.of(
+                                                (Double) row.getField("label"),
+                                                (Double) row.getField("prediction"));
+                                    }
+                                });
+        return IteratorUtils.toList(stream.executeAndCollect());
+    }
+
+    private static void verifyClusteringResult(List<Tuple2<Double, Double>> result) {
+        for (Tuple2<Double, Double> t2 : result) {
+            Assert.assertEquals(t2.f0, t2.f1);
+        }
+    }
+
+    /** Tests Param. */
+    @Test
+    public void testParam() {
+        Knn knn = new Knn();
+        assertEquals("features", knn.getFeaturesCol());
+        assertEquals("label", knn.getLabelCol());
+        assertEquals(10L, knn.getK().longValue());
+        assertEquals("prediction", knn.getPredictionCol());
+
+        knn.setLabelCol("test_label")
+                .setFeaturesCol("test_features")
+                .setK(4)
+                .setPredictionCol("test_prediction");
+
+        assertEquals("test_features", knn.getFeaturesCol());
+        assertEquals("test_label", knn.getLabelCol());
+        assertEquals(4L, knn.getK().longValue());
+        assertEquals("test_prediction", knn.getPredictionCol());
+    }
+
+    @Test
+    public void testFeaturePredictionParam() throws Exception {
+        Knn knn =
+                new Knn()
+                        .setLabelCol("label")
+                        .setFeaturesCol("features")
+                        .setK(4)
+                        .setPredictionCol("prediction");
+        KnnModel model = knn.fit(trainData);
+        Table output = model.transform(testData)[0];
+
+        assertEquals(
+                Arrays.asList("label", "features", "prediction"),
+                output.getResolvedSchema().getColumnNames());
+
+        List<Tuple2<Double, Double>> result = executeAndCollect(output);
+        verifyClusteringResult(result);
+    }
+
+    @Test
+    public void testFewerDistinctPointsThanCluster() throws Exception {
+        Knn knn =
+                new Knn()
+                        .setLabelCol("label")
+                        .setFeaturesCol("features")
+                        .setK(4)
+                        .setPredictionCol("prediction");
+        KnnModel model = knn.fit(testData);
+        Table output = model.transform(testData)[0];
+
+        assertEquals(
+                Arrays.asList("label", "features", "prediction"),
+                output.getResolvedSchema().getColumnNames());
+        executeAndCollect(output);
+    }
+
+    @Test
+    public void testFitAndPredict() throws Exception {
+        Knn knn =
+                new Knn()
+                        .setLabelCol("label")
+                        .setFeaturesCol("features")
+                        .setK(4)
+                        .setPredictionCol("prediction");
+        KnnModel knnModel = knn.fit(trainData);
+        Table output = knnModel.transform(testData)[0];
+        List<Tuple2<Double, Double>> result = executeAndCollect(output);
+        verifyClusteringResult(result);
+    }
+
+    @Test
+    public void testSaveLoadAndPredict() throws Exception {
+        String path = Files.createTempDirectory("").toString();
+        Knn knn =
+                new Knn()
+                        .setLabelCol("label")
+                        .setFeaturesCol("features")
+                        .setK(4)
+                        .setPredictionCol("prediction");
+        knn.save(path);
+
+        Knn loadKnn = Knn.load(env, path);
+        KnnModel knnModel = loadKnn.fit(trainData);
+        Table output = knnModel.transform(testData)[0];
+
+        List<Tuple2<Double, Double>> result = executeAndCollect(output);
+        verifyClusteringResult(result);
+    }
+
+    @Test
+    public void testModelSaveLoadAndPredict() throws Exception {
+        String path = Files.createTempDirectory("").toString();
+        Knn knn =
+                new Knn()
+                        .setLabelCol("label")
+                        .setFeaturesCol("features")
+                        .setK(4)
+                        .setPredictionCol("prediction");
+        KnnModel knnModel = knn.fit(trainData);
+        knnModel.save(path);
+        env.execute();
+
+        KnnModel newModel = KnnModel.load(env, path);
+        Table output = newModel.transform(testData)[0];
+        List<Tuple2<Double, Double>> result = executeAndCollect(output);
+        verifyClusteringResult(result);
+    }
+
+    @Test
+    public void testGetModelData() throws Exception {
+        Knn knn =
+                new Knn()
+                        .setLabelCol("label")
+                        .setFeaturesCol("features")
+                        .setK(4)
+                        .setPredictionCol("prediction");
+
+        KnnModel knnModel = knn.fit(trainData);
+        Table modelData = knnModel.getModelData()[0];
+
+        DataStream<Row> output = tEnv.toDataStream(modelData);
+        assertEquals("f0", modelData.getResolvedSchema().getColumnNames().get(0));
+
+        List<Row> modelRows = IteratorUtils.toList(output.executeAndCollect());
+
+        KnnModelData data = (KnnModelData) modelRows.get(0).getField(0);
+
+        assert data != null;

Review comment:
       nit: Could we use `Assert.assertNotNull(...)` for consistency with the `Assert.assertEquals(...)` calls used below?
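   The stated reason is consistency. A further, general point: a plain Java `assert` statement is only evaluated when the JVM runs with `-ea`. Maven Surefire enables assertions by default, so under Surefire the check does fire, but an explicit assertion does not depend on that setting at all. The hypothetical `NullCheckDemo` class below (plain JDK only, not part of JUnit or Flink ML) is a minimal sketch of the difference; in the test itself the change would just be `Assert.assertNotNull(data);`.

```java
// Hypothetical demo class (not part of Flink ML or JUnit).
final class NullCheckDemo {

    // True only when the JVM was started with -ea: the assignment inside the
    // assert statement is evaluated only if assertions are enabled.
    static boolean assertionsEnabled() {
        boolean enabled = false;
        assert enabled = true;
        return enabled;
    }

    // Behaves like `assert data != null;` in the test: the check is skipped
    // entirely unless assertions are enabled.
    static boolean plainAssertRejectsNull(Object data) {
        try {
            assert data != null;
            return false;
        } catch (AssertionError e) {
            return true;
        }
    }

    // Behaves like an explicit check such as JUnit's Assert.assertNotNull:
    // it always runs, independent of the -ea flag.
    static boolean explicitCheckRejectsNull(Object data) {
        try {
            if (data == null) {
                throw new AssertionError("data must not be null");
            }
            return false;
        } catch (AssertionError e) {
            return true;
        }
    }

    public static void main(String[] args) {
        System.out.println("assertions enabled: " + assertionsEnabled());
        System.out.println("plain assert rejects null: " + plainAssertRejectsNull(null));
        System.out.println("explicit check rejects null: " + explicitCheckRejectsNull(null));
    }
}
```

   Running it with and without `-ea` shows that only the plain `assert` changes behavior.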

##########
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##########
@@ -0,0 +1,287 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+
+/** Tests Knn and KnnModel. */
+public class KnnTest {
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainData;
+    private static final List<Row> trainArray =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(1.0, Vectors.dense(2.0, 3.0)),
+                            Row.of(1.0, Vectors.dense(2.1, 3.1)),
+                            Row.of(2.0, Vectors.dense(200.1, 300.1)),
+                            Row.of(2.0, Vectors.dense(200.2, 300.2)),
+                            Row.of(2.0, Vectors.dense(200.3, 300.3)),
+                            Row.of(2.0, Vectors.dense(200.4, 300.4)),
+                            Row.of(2.0, Vectors.dense(200.4, 300.4)),
+                            Row.of(2.0, Vectors.dense(200.6, 300.6)),
+                            Row.of(1.0, Vectors.dense(2.1, 3.1)),
+                            Row.of(1.0, Vectors.dense(2.1, 3.1)),
+                            Row.of(1.0, Vectors.dense(2.1, 3.1)),
+                            Row.of(1.0, Vectors.dense(2.1, 3.1)),
+                            Row.of(1.0, Vectors.dense(2.3, 3.2)),
+                            Row.of(1.0, Vectors.dense(2.3, 3.2)),
+                            Row.of(3.0, Vectors.dense(2.8, 3.2)),
+                            Row.of(4.0, Vectors.dense(300., 3.2)),
+                            Row.of(1.0, Vectors.dense(2.2, 3.2)),
+                            Row.of(5.0, Vectors.dense(2.4, 3.2)),
+                            Row.of(5.0, Vectors.dense(2.5, 3.2)),
+                            Row.of(5.0, Vectors.dense(2.5, 3.2)),
+                            Row.of(1.0, Vectors.dense(2.1, 3.1))));
+
+    private static final List<Row> testArray =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(5.0, Vectors.dense(4.0, 4.1)),
+                            Row.of(2.0, Vectors.dense(300, 42))));
+    private Table testData;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        Schema schema =
+                Schema.newBuilder()
+                        .column("f0", DataTypes.DOUBLE())
+                        .column("f1", DataTypes.of(DenseVector.class))
+                        .build();
+
+        DataStream<Row> dataStream = env.fromCollection(trainArray);
+        trainData = tEnv.fromDataStream(dataStream, schema).as("label", "features");
+
+        DataStream<Row> predDataStream = env.fromCollection(testArray);
+        testData = tEnv.fromDataStream(predDataStream, schema).as("label", "features");
+    }
+
+    // Executes the graph and returns a list which has the true label and the predicted label.
+    private static List<Tuple2<Double, Double>> executeAndCollect(Table output) throws Exception {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
+
+        DataStream<Tuple2<Double, Double>> stream =
+                tEnv.toDataStream(output)
+                        .map(
+                                new MapFunction<Row, Tuple2<Double, Double>>() {
+                                    @Override
+                                    public Tuple2<Double, Double> map(Row row) {
+                                        return Tuple2.of(
+                                                (Double) row.getField("label"),
+                                                (Double) row.getField("prediction"));
+                                    }
+                                });
+        return IteratorUtils.toList(stream.executeAndCollect());
+    }
+
+    private static void verifyClusteringResult(List<Tuple2<Double, Double>> result) {
+        for (Tuple2<Double, Double> t2 : result) {
+            Assert.assertEquals(t2.f0, t2.f1);
+        }
+    }
+
+    /** Tests Param. */
+    @Test
+    public void testParam() {
+        Knn knn = new Knn();
+        assertEquals("features", knn.getFeaturesCol());
+        assertEquals("label", knn.getLabelCol());
+        assertEquals(10L, knn.getK().longValue());
+        assertEquals("prediction", knn.getPredictionCol());
+
+        knn.setLabelCol("test_label")
+                .setFeaturesCol("test_features")
+                .setK(4)
+                .setPredictionCol("test_prediction");
+
+        assertEquals("test_features", knn.getFeaturesCol());
+        assertEquals("test_label", knn.getLabelCol());
+        assertEquals(4L, knn.getK().longValue());
+        assertEquals("test_prediction", knn.getPredictionCol());
+    }
+
+    @Test
+    public void testFeaturePredictionParam() throws Exception {

Review comment:
       The purpose of `testFeaturePredictionParam` in `LogisticRegressionTest` and `NaiveBayesTest` is to verify that the output schema is consistent with the column names specified by users, so the test needs to use non-default column names.
   
   Could you update the test to use non-default column names, similar to `LogisticRegressionTest::testFeaturePredictionParam`?

##########
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##########
@@ -0,0 +1,287 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+
+/** Tests Knn and KnnModel. */
+public class KnnTest {
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainData;
+    private static final List<Row> trainArray =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(1.0, Vectors.dense(2.0, 3.0)),
+                            Row.of(1.0, Vectors.dense(2.1, 3.1)),
+                            Row.of(2.0, Vectors.dense(200.1, 300.1)),
+                            Row.of(2.0, Vectors.dense(200.2, 300.2)),
+                            Row.of(2.0, Vectors.dense(200.3, 300.3)),
+                            Row.of(2.0, Vectors.dense(200.4, 300.4)),
+                            Row.of(2.0, Vectors.dense(200.4, 300.4)),
+                            Row.of(2.0, Vectors.dense(200.6, 300.6)),
+                            Row.of(1.0, Vectors.dense(2.1, 3.1)),
+                            Row.of(1.0, Vectors.dense(2.1, 3.1)),
+                            Row.of(1.0, Vectors.dense(2.1, 3.1)),
+                            Row.of(1.0, Vectors.dense(2.1, 3.1)),
+                            Row.of(1.0, Vectors.dense(2.3, 3.2)),
+                            Row.of(1.0, Vectors.dense(2.3, 3.2)),
+                            Row.of(3.0, Vectors.dense(2.8, 3.2)),
+                            Row.of(4.0, Vectors.dense(300., 3.2)),
+                            Row.of(1.0, Vectors.dense(2.2, 3.2)),
+                            Row.of(5.0, Vectors.dense(2.4, 3.2)),
+                            Row.of(5.0, Vectors.dense(2.5, 3.2)),
+                            Row.of(5.0, Vectors.dense(2.5, 3.2)),
+                            Row.of(1.0, Vectors.dense(2.1, 3.1))));
+
+    private static final List<Row> testArray =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(5.0, Vectors.dense(4.0, 4.1)),
+                            Row.of(2.0, Vectors.dense(300, 42))));
+    private Table testData;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        Schema schema =
+                Schema.newBuilder()
+                        .column("f0", DataTypes.DOUBLE())
+                        .column("f1", DataTypes.of(DenseVector.class))
+                        .build();
+
+        DataStream<Row> dataStream = env.fromCollection(trainArray);
+        trainData = tEnv.fromDataStream(dataStream, schema).as("label", "features");
+
+        DataStream<Row> predDataStream = env.fromCollection(testArray);
+        testData = tEnv.fromDataStream(predDataStream, schema).as("label", "features");
+    }
+
+    // Executes the graph and returns a list which has the true label and the predicted label.
+    private static List<Tuple2<Double, Double>> executeAndCollect(Table output) throws Exception {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
+
+        DataStream<Tuple2<Double, Double>> stream =
+                tEnv.toDataStream(output)
+                        .map(
+                                new MapFunction<Row, Tuple2<Double, Double>>() {
+                                    @Override
+                                    public Tuple2<Double, Double> map(Row row) {
+                                        return Tuple2.of(
+                                                (Double) row.getField("label"),
+                                                (Double) row.getField("prediction"));
+                                    }
+                                });
+        return IteratorUtils.toList(stream.executeAndCollect());
+    }
+
+    private static void verifyClusteringResult(List<Tuple2<Double, Double>> result) {
+        for (Tuple2<Double, Double> t2 : result) {
+            Assert.assertEquals(t2.f0, t2.f1);
+        }
+    }
+
+    /** Tests Param. */
+    @Test
+    public void testParam() {
+        Knn knn = new Knn();
+        assertEquals("features", knn.getFeaturesCol());
+        assertEquals("label", knn.getLabelCol());
+        assertEquals(10L, knn.getK().longValue());
+        assertEquals("prediction", knn.getPredictionCol());
+
+        knn.setLabelCol("test_label")
+                .setFeaturesCol("test_features")
+                .setK(4)
+                .setPredictionCol("test_prediction");
+
+        assertEquals("test_features", knn.getFeaturesCol());
+        assertEquals("test_label", knn.getLabelCol());
+        assertEquals(4L, knn.getK().longValue());
+        assertEquals("test_prediction", knn.getPredictionCol());
+    }
+
+    @Test
+    public void testFeaturePredictionParam() throws Exception {
+        Knn knn =
+                new Knn()
+                        .setLabelCol("label")
+                        .setFeaturesCol("features")
+                        .setK(4)
+                        .setPredictionCol("prediction");
+        KnnModel model = knn.fit(trainData);
+        Table output = model.transform(testData)[0];
+
+        assertEquals(
+                Arrays.asList("label", "features", "prediction"),
+                output.getResolvedSchema().getColumnNames());
+
+        List<Tuple2<Double, Double>> result = executeAndCollect(output);
+        verifyClusteringResult(result);
+    }
+
+    @Test
+    public void testFewerDistinctPointsThanCluster() throws Exception {
+        Knn knn =
+                new Knn()
+                        .setLabelCol("label")
+                        .setFeaturesCol("features")
+                        .setK(4)
+                        .setPredictionCol("prediction");
+        KnnModel model = knn.fit(testData);
+        Table output = model.transform(testData)[0];
+
+        assertEquals(
+                Arrays.asList("label", "features", "prediction"),
+                output.getResolvedSchema().getColumnNames());
+        executeAndCollect(output);
+    }
+
+    @Test
+    public void testFitAndPredict() throws Exception {
+        Knn knn =
+                new Knn()
+                        .setLabelCol("label")
+                        .setFeaturesCol("features")
+                        .setK(4)
+                        .setPredictionCol("prediction");
+        KnnModel knnModel = knn.fit(trainData);
+        Table output = knnModel.transform(testData)[0];
+        List<Tuple2<Double, Double>> result = executeAndCollect(output);
+        verifyClusteringResult(result);
+    }
+
+    @Test
+    public void testSaveLoadAndPredict() throws Exception {
+        String path = Files.createTempDirectory("").toString();
+        Knn knn =
+                new Knn()
+                        .setLabelCol("label")
+                        .setFeaturesCol("features")
+                        .setK(4)
+                        .setPredictionCol("prediction");
+        knn.save(path);
+
+        Knn loadKnn = Knn.load(env, path);
+        KnnModel knnModel = loadKnn.fit(trainData);
+        Table output = knnModel.transform(testData)[0];
+
+        List<Tuple2<Double, Double>> result = executeAndCollect(output);
+        verifyClusteringResult(result);
+    }
+
+    @Test
+    public void testModelSaveLoadAndPredict() throws Exception {
+        String path = Files.createTempDirectory("").toString();
+        Knn knn =
+                new Knn()
+                        .setLabelCol("label")
+                        .setFeaturesCol("features")
+                        .setK(4)
+                        .setPredictionCol("prediction");
+        KnnModel knnModel = knn.fit(trainData);
+        knnModel.save(path);

Review comment:
       Could we update the code to use `StageTestUtils.saveAndReload(...)`? 
   
   This utility method is provided in the NaiveBayes PR. The main benefit of this approach is that it uses `org.junit.rules.TemporaryFolder` to make sure the temporary directory is deleted after the test finishes.
   
   Also, do you think it would be better to merge `testModelSaveLoadAndPredict` into `testSaveLoadAndPredict`, similar to `NaiveBayesTest::testSaveLoadAndPredict()`? Ideally we can settle on a best practice for testing similar functionality across algorithms, so that our algorithm test code stays neat and consistent.
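   For reference, a minimal sketch of the cleanup guarantee described above. `Files.createTempDirectory("")`, as used in the current test, never removes the directory, whereas JUnit 4's `TemporaryFolder` rule (declared as `@Rule public final TemporaryFolder tempFolder = new TemporaryFolder();`, with `tempFolder.newFolder().getAbsolutePath()` giving a fresh path) deletes it after each test. The hypothetical `TempDirs.withTempDir` helper below reproduces that behavior with the plain JDK so it can run without JUnit or Flink; it is only an illustration and is not `StageTestUtils.saveAndReload(...)`, whose signature is not shown here.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Comparator;
import java.util.function.Consumer;
import java.util.stream.Stream;

// Hypothetical helper (not part of Flink ML or JUnit) showing the cleanup
// guarantee that org.junit.rules.TemporaryFolder provides automatically.
final class TempDirs {

    // Creates a fresh temporary directory, runs `body` with it, and always
    // deletes the directory afterwards, even if `body` throws.
    static Path withTempDir(Consumer<Path> body) {
        Path dir;
        try {
            dir = Files.createTempDirectory("knn-test");
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        try {
            body.accept(dir);
        } finally {
            deleteRecursively(dir);
        }
        return dir;
    }

    // Deletes `root` and everything below it. Sorting the walked paths in
    // reverse order visits children before their parent directories.
    static void deleteRecursively(Path root) {
        if (!Files.exists(root)) {
            return;
        }
        try (Stream<Path> paths = Files.walk(root)) {
            paths.sorted(Comparator.reverseOrder()).forEach(p -> {
                try {
                    Files.delete(p);
                } catch (IOException e) {
                    throw new UncheckedIOException(e);
                }
            });
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

   In the test itself the JUnit rule remains the more idiomatic choice, since the framework creates and deletes the folder around each test method without any `try`/`finally` in the test body.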

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,317 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.classification.knn.KnnModelData.ModelDataDecoder;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+
+/** Knn model fitted by estimator. */
+public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> {
+    protected Map<Param<?>, Object> params = new HashMap<>();
+    private Table modelDataTable;
+
+    public KnnModel() {
+        ParamUtils.initializeMapWithDefaultValues(params, this);
+    }
+
+    @Override
+    public KnnModel setModelData(Table... modelData) {
+        this.modelDataTable = modelData[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Row> input = tEnv.toDataStream(inputs[0]);
+        DataStream<KnnModelData> model = KnnModelData.getModelDataStream(modelDataTable);
+        final String broadcastKey = "broadcastModelKey";
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(), BasicTypeInfo.DOUBLE_TYPE_INFO),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol()));
+
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(input),
+                        Collections.singletonMap(broadcastKey, model),
+                        inputList -> {
+                            DataStream inoutData = inputList.get(0);
+                            return inoutData.map(
+                                    new PredictLabelFunction(
+                                            broadcastKey, getK(), getFeaturesCol()),
+                                    outputTypeInfo);
+                        });
+
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return this.params;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                KnnModelData.getModelDataStream(modelDataTable),
+                path,
+                new KnnModelData.ModelDataEncoder(),
+                0);
+    }
+
+    /**
+     * Loads model data from path.
+     *
+     * @param env Stream execution environment.
+     * @param path Model path.
+     * @return Knn model.
+     */
+    public static KnnModel load(StreamExecutionEnvironment env, String path) throws IOException {
+        KnnModel model = ReadWriteUtils.loadStageParam(path);
+        Table modelDataTable = ReadWriteUtils.loadModelData(env, path, new ModelDataDecoder(), 0);
+        return model.setModelData(modelDataTable);
+    }
+
+    /** This operator loads model data and predicts result. */
+    private static class PredictLabelFunction extends RichMapFunction<Row, Row> {
+        private boolean firstElement = true;
+        private final String featureCol;
+        private transient KnnModelData knnModelData;
+        private final int topN;
+        private final String broadcastKey;
+        private transient Comparator<? super Tuple2<Double, Double>> comparator;
+
+        public PredictLabelFunction(String broadcastKey, int k, String featureCol) {
+            this.topN = k;
+            this.broadcastKey = broadcastKey;
+            this.featureCol = featureCol;
+        }
+
+        @Override
+        public Row map(Row row) {
+            if (firstElement) {

Review comment:
       Instead of using `firstElement`, could we just check `knnModelData == null` here?
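A minimal plain-Java sketch of the suggested pattern, in which the model field itself acts as the "already loaded" marker. `LazyModelMapper` is a hypothetical stand-in for `PredictLabelFunction`, and the `Supplier` stands in for `getRuntimeContext().getBroadcastVariable(broadcastKey)`; the Flink runtime is not involved.

```java
import java.util.List;
import java.util.function.Supplier;

/** Hypothetical stand-in for PredictLabelFunction: loads broadcast model data on first use. */
final class LazyModelMapper {
    // Plays the role of getRuntimeContext().getBroadcastVariable(broadcastKey).
    private final Supplier<List<Double>> broadcastSupplier;
    // Null until the first map() call; doubles as the "already loaded" marker.
    private List<Double> modelData;
    private int loadCount;

    LazyModelMapper(Supplier<List<Double>> broadcastSupplier) {
        this.broadcastSupplier = broadcastSupplier;
    }

    double map(double x) {
        if (modelData == null) { // replaces a separate firstElement flag
            modelData = broadcastSupplier.get();
            loadCount++;
        }
        // Toy prediction: shift the input by the first model value.
        return x + modelData.get(0);
    }

    int loadCount() {
        return loadCount;
    }
}
```

This relies on the loader never producing `null`. If decoding the broadcast data could legitimately yield `null`, a separate flag would still be needed.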

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,317 @@
+/** Knn model fitted by estimator. */
+public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> {
+    protected Map<Param<?>, Object> params = new HashMap<>();
+    private Table modelDataTable;
+
+    public KnnModel() {
+        ParamUtils.initializeMapWithDefaultValues(params, this);
+    }
+
+    @Override
+    public KnnModel setModelData(Table... modelData) {
+        this.modelDataTable = modelData[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Row> input = tEnv.toDataStream(inputs[0]);
+        DataStream<KnnModelData> model = KnnModelData.getModelDataStream(modelDataTable);
+        final String broadcastKey = "broadcastModelKey";
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(), BasicTypeInfo.DOUBLE_TYPE_INFO),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol()));
+
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(input),
+                        Collections.singletonMap(broadcastKey, model),
+                        inputList -> {
+                            DataStream inoutData = inputList.get(0);
+                            return inoutData.map(
+                                    new PredictLabelFunction(
+                                            broadcastKey, getK(), getFeaturesCol()),
+                                    outputTypeInfo);
+                        });
+
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return this.params;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                KnnModelData.getModelDataStream(modelDataTable),
+                path,
+                new KnnModelData.ModelDataEncoder(),
+                0);
+    }
+
+    /**
+     * Loads model data from path.
+     *
+     * @param env Stream execution environment.
+     * @param path Model path.
+     * @return Knn model.
+     */
+    public static KnnModel load(StreamExecutionEnvironment env, String path) throws IOException {
+        KnnModel model = ReadWriteUtils.loadStageParam(path);
+        Table modelDataTable = ReadWriteUtils.loadModelData(env, path, new ModelDataDecoder(), 0);
+        return model.setModelData(modelDataTable);
+    }
+
+    /** This operator loads model data and predicts result. */
+    private static class PredictLabelFunction extends RichMapFunction<Row, Row> {
+        private boolean firstElement = true;
+        private final String featureCol;
+        private transient KnnModelData knnModelData;

Review comment:
       `NaiveBayesModel` and `LogisticRegressionModel` do not use `transient` 
for variables obtained from the broadcast input. Could you explain why we 
need to use `transient` here? If there is a good reason, should we do the same 
for NaiveBayes and LogisticRegression?
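One way to reason about the question: assuming the function object is shipped to the tasks with standard Java serialization (which is how Flink distributes user functions), `transient` only changes behavior for fields that are already non-null when the function is serialized. Fields such as `knnModelData` and `comparator` are assigned inside `map()`, after deserialization, so they are still `null` at that point and would serialize fine either way. The sketch below uses hypothetical class names and plain `java.io` serialization to show the three cases.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

/** Deliberately NOT Serializable, standing in for decoded model data. */
final class DecodedModel {
    final double[] weights;

    DecodedModel(double[] weights) {
        this.weights = weights;
    }
}

/** Hypothetical function object with one plain and one transient model field. */
final class PredictFunctionSketch implements Serializable {
    private static final long serialVersionUID = 1L;
    final int topN;
    DecodedModel plainModel; // serialized if non-null
    transient DecodedModel cachedModel; // always skipped by Java serialization

    PredictFunctionSketch(int topN) {
        this.topN = topN;
    }
}

final class SerializationCheck {
    private SerializationCheck() {}

    /** Serializes and deserializes a value, as when a function is shipped to a task. */
    @SuppressWarnings("unchecked")
    static <T extends Serializable> T roundTrip(T value) throws IOException, ClassNotFoundException {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
            out.writeObject(value);
        }
        try (ObjectInputStream in =
                new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
            return (T) in.readObject();
        }
    }
}
```

Under that reading, `transient` on these fields is a defensive marker (it documents that the value is runtime-only) rather than a correctness requirement, which would also be consistent with NaiveBayes and LogisticRegression working without it.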

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,317 @@
+/** Knn model fitted by estimator. */
+public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> {
+    protected Map<Param<?>, Object> params = new HashMap<>();
+    private Table modelDataTable;
+
+    public KnnModel() {
+        ParamUtils.initializeMapWithDefaultValues(params, this);
+    }
+
+    @Override
+    public KnnModel setModelData(Table... modelData) {
+        this.modelDataTable = modelData[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Row> input = tEnv.toDataStream(inputs[0]);
+        DataStream<KnnModelData> model = KnnModelData.getModelDataStream(modelDataTable);
+        final String broadcastKey = "broadcastModelKey";
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(), BasicTypeInfo.DOUBLE_TYPE_INFO),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol()));
+
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(input),
+                        Collections.singletonMap(broadcastKey, model),
+                        inputList -> {
+                            DataStream inoutData = inputList.get(0);
+                            return inoutData.map(
+                                    new PredictLabelFunction(
+                                            broadcastKey, getK(), getFeaturesCol()),
+                                    outputTypeInfo);
+                        });
+
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return this.params;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                KnnModelData.getModelDataStream(modelDataTable),
+                path,
+                new KnnModelData.ModelDataEncoder(),
+                0);
+    }
+
+    /**
+     * Loads model data from path.
+     *
+     * @param env Stream execution environment.
+     * @param path Model path.
+     * @return Knn model.
+     */
+    public static KnnModel load(StreamExecutionEnvironment env, String path) throws IOException {
+        KnnModel model = ReadWriteUtils.loadStageParam(path);
+        Table modelDataTable = ReadWriteUtils.loadModelData(env, path, new ModelDataDecoder(), 0);
+        return model.setModelData(modelDataTable);
+    }
+
+    /** This operator loads model data and predicts result. */
+    private static class PredictLabelFunction extends RichMapFunction<Row, Row> {
+        private boolean firstElement = true;
+        private final String featureCol;
+        private transient KnnModelData knnModelData;
+        private final int topN;
+        private final String broadcastKey;
+        private transient Comparator<? super Tuple2<Double, Double>> comparator;
+
+        public PredictLabelFunction(String broadcastKey, int k, String featureCol) {
+            this.topN = k;
+            this.broadcastKey = broadcastKey;
+            this.featureCol = featureCol;
+        }
+
+        @Override
+        public Row map(Row row) {
+            if (firstElement) {
+                comparator = Comparator.comparingDouble(o -> -o.f0);
+                loadModel(getRuntimeContext().getBroadcastVariable(broadcastKey));
+                firstElement = false;
+            }
+            DenseVector vector = (DenseVector) row.getField(featureCol);
+            Tuple2<List<Double>, List<Double>> t2 = findNeighbor(vector, topN);
+            return Row.join(row, Row.of(getResult(t2)));
+        }
+
+        /**
+         * Finds the nearest topN neighbors from whole nodes.
+         *
+         * @param input Input vector.
+         * @param topN Top N.
+         * @return Neighbors.
+         */
+        private Tuple2<List<Double>, List<Double>> findNeighbor(DenseVector input, Integer topN) {

Review comment:
       Since we already have `topN` as a class member variable, do we still need 
to pass it as a function parameter?
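A sketch of the suggested shape: `topN` is set once in the constructor and `findNeighbors` takes only the query. `NeighborFinder` is hypothetical. It uses plain arrays and squared Euclidean distance instead of the PR's `DenseVector`/`DenseMatrix`/`BLAS` code (whose distance computation is not shown here), and a full sort instead of the PR's `PriorityQueue`, to keep it short.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

/** Hypothetical k-nearest-neighbor lookup; topN is fixed at construction time. */
final class NeighborFinder {
    private final double[][] points;
    private final double[] labels;
    private final int topN;

    NeighborFinder(double[][] points, double[] labels, int topN) {
        if (points.length != labels.length) {
            throw new IllegalArgumentException("points and labels differ in length");
        }
        this.points = points;
        this.labels = labels;
        this.topN = topN;
    }

    /** Returns the labels of the topN points closest to input, nearest first. */
    List<Double> findNeighbors(double[] input) {
        double[] dist = new double[points.length];
        Integer[] order = new Integer[points.length];
        for (int i = 0; i < points.length; i++) {
            dist[i] = squaredDistance(points[i], input);
            order[i] = i;
        }
        // Stable sort of point indices by distance (ties keep their original order).
        Arrays.sort(order, Comparator.comparingDouble(idx -> dist[idx]));
        List<Double> result = new ArrayList<>();
        for (int i = 0; i < Math.min(topN, order.length); i++) {
            result.add(labels[order[i]]);
        }
        return result;
    }

    private static double squaredDistance(double[] a, double[] b) {
        double sum = 0.0;
        for (int j = 0; j < a.length; j++) {
            double d = a[j] - b[j];
            sum += d * d;
        }
        return sum;
    }
}
```

The full sort costs O(n log n) per query; a bounded heap of size `topN`, as the PR uses, brings that down to O(n log topN).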

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,311 @@
+/** Knn model fitted by estimator. */
+public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> {
+    protected Map<Param<?>, Object> params = new HashMap<>();
+    private Table modelDataTable;
+
+    public KnnModel() {
+        ParamUtils.initializeMapWithDefaultValues(params, this);
+    }
+
+    @Override
+    public KnnModel setModelData(Table... modelData) {
+        this.modelDataTable = modelData[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Row> input = tEnv.toDataStream(inputs[0]);
+        DataStream<KnnModelData> model = KnnModelData.getModelDataStream(modelDataTable);
+        final String broadcastModelKey = "broadcastModelKey";
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(), BasicTypeInfo.DOUBLE_TYPE_INFO),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol()));
+
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(input),
+                        Collections.singletonMap(broadcastModelKey, model),
+                        inputList -> {
+                            DataStream inoutData = inputList.get(0);
+                            return inoutData.map(
+                                    new PredictLabelFunction(
+                                            broadcastModelKey, getK(), getFeaturesCol()),
+                                    outputTypeInfo);
+                        });
+
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return this.params;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                KnnModelData.getModelDataStream(modelDataTable),
+                path,
+                new KnnModelData.ModelDataEncoder(),
+                0);
+    }
+
+    /**
+     * Loads model data from path.
+     *
+     * @param env Stream execution environment.
+     * @param path Model path.
+     * @return Knn model.
+     */
+    public static KnnModel load(StreamExecutionEnvironment env, String path) throws IOException {
+        KnnModel model = ReadWriteUtils.loadStageParam(path);
+        Table modelDataTable = ReadWriteUtils.loadModelData(env, path, new ModelDataDecoder(), 0);
+        return model.setModelData(modelDataTable);
+    }
+
+    /** This operator loads model data and predicts result. */
+    private static class PredictLabelFunction extends RichMapFunction<Row, Row> {
+        private boolean firstElement = true;
+        private final String featureCol;
+        private transient KnnModelData knnModelData;
+        private final int topN;
+        private final String broadcastKey;
+        private transient Comparator<? super Tuple2<Double, Double>> comparator;
+
+        public PredictLabelFunction(String broadcastKey, int k, String featureCol) {
+            this.topN = k;
+            this.broadcastKey = broadcastKey;
+            this.featureCol = featureCol;
+        }
+
+        @Override
+        public Row map(Row row) {
+            if (firstElement) {
+                comparator = Comparator.comparingDouble(o -> -o.f0);
+                loadModel(getRuntimeContext().getBroadcastVariable(broadcastKey));
+                firstElement = false;
+            }
+            DenseVector vector = (DenseVector) row.getField(featureCol);
+            Tuple2<List<Double>, List<Double>> t2 = findNeighbor(vector, topN);
+            return Row.join(row, Row.of(getResult(t2)));
+        }
+
+        /**
+         * Finds the nearest topN neighbors from whole nodes.
+         *
+         * @param input Input vector.
+         * @param topN Top N.
+         * @return Neighbors.
+         */
+        private Tuple2<List<Double>, List<Double>> findNeighbor(DenseVector input, Integer topN) {
+            PriorityQueue<Tuple2<Double, Double>> priorityQueue = new PriorityQueue<>(comparator);
+            search(input, topN, priorityQueue);
+            List<Double> items = new ArrayList<>();
+            List<Double> metrics = new ArrayList<>();
+            while (!priorityQueue.isEmpty()) {
+                Tuple2<Double, Double> result = priorityQueue.poll();
+                items.add(result.f1);
+                metrics.add(result.f0);
+            }
+            return Tuple2.of(items, metrics);
+        }
+
+        /**
+         * @param input Input vector.

Review comment:
       Could we add more information to the Javadoc about the 
functionality/purpose of this method?
   
   Same for `updateQueue` and other methods.
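As an illustration of the level of detail that might help, here is a hypothetical `updateQueue` whose Javadoc states its purpose, invariant, and return value. The PR's actual `updateQueue` body is not shown in this excerpt, so the bounded max-heap semantics and the `{distance, label}` encoding below are assumptions.

```java
import java.util.Comparator;
import java.util.PriorityQueue;

/** Hypothetical helpers illustrating a documented bounded neighbor queue. */
final class BoundedNeighborQueue {
    private BoundedNeighborQueue() {}

    /** Creates an empty heap whose head is the candidate with the largest distance. */
    static PriorityQueue<double[]> newMaxHeap() {
        return new PriorityQueue<>(Comparator.comparingDouble((double[] e) -> e[0]).reversed());
    }

    /**
     * Offers one candidate neighbor to a bounded max-heap of the closest candidates seen so far.
     *
     * <p>The head of the heap is the farthest retained candidate. While the heap holds fewer than
     * {@code topN} entries, every candidate is added. Once it is full, a candidate replaces the
     * head only if it is strictly closer, so the heap never grows beyond {@code topN}.
     *
     * @param heap max-heap of {@code {distance, label}} pairs, farthest at the head
     * @param topN maximum number of neighbors to retain; must be positive
     * @param distance distance from the query to the candidate
     * @param label label of the candidate
     * @return {@code true} if the candidate was retained, {@code false} if it was discarded
     */
    static boolean updateQueue(
            PriorityQueue<double[]> heap, int topN, double distance, double label) {
        if (topN <= 0) {
            throw new IllegalArgumentException("topN must be positive: " + topN);
        }
        if (heap.size() < topN) {
            heap.add(new double[] {distance, label});
            return true;
        }
        if (distance < heap.peek()[0]) {
            heap.poll(); // evict the current farthest candidate
            heap.add(new double[] {distance, label});
            return true;
        }
        return false;
    }
}
```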

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,317 @@
+/** Knn model fitted by estimator. */
+public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> {
+    protected Map<Param<?>, Object> params = new HashMap<>();
+    private Table modelDataTable;
+
+    public KnnModel() {
+        ParamUtils.initializeMapWithDefaultValues(params, this);
+    }
+
+    @Override
+    public KnnModel setModelData(Table... modelData) {
+        this.modelDataTable = modelData[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Row> input = tEnv.toDataStream(inputs[0]);
+        DataStream<KnnModelData> model = KnnModelData.getModelDataStream(modelDataTable);
+        final String broadcastKey = "broadcastModelKey";
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(), BasicTypeInfo.DOUBLE_TYPE_INFO),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol()));
+
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(input),
+                        Collections.singletonMap(broadcastKey, model),
+                        inputList -> {
+                            DataStream inoutData = inputList.get(0);
+                            return inoutData.map(
+                                    new PredictLabelFunction(
+                                            broadcastKey, getK(), getFeaturesCol()),
+                                    outputTypeInfo);
+                        });
+
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return this.params;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                KnnModelData.getModelDataStream(modelDataTable),
+                path,
+                new KnnModelData.ModelDataEncoder(),
+                0);
+    }
+
+    /**
+     * Loads model data from path.
+     *
+     * @param env Stream execution environment.
+     * @param path Model path.
+     * @return Knn model.
+     */
+    public static KnnModel load(StreamExecutionEnvironment env, String path) throws IOException {
+        KnnModel model = ReadWriteUtils.loadStageParam(path);
+        Table modelDataTable = ReadWriteUtils.loadModelData(env, path, new ModelDataDecoder(), 0);
+        return model.setModelData(modelDataTable);
+    }
+
+    /** This operator loads model data and predicts result. */
+    private static class PredictLabelFunction extends RichMapFunction<Row, Row> {
+        private boolean firstElement = true;
+        private final String featureCol;
+        private transient KnnModelData knnModelData;
+        private final int topN;
+        private final String broadcastKey;
+        private transient Comparator<? super Tuple2<Double, Double>> comparator;
+
+        public PredictLabelFunction(String broadcastKey, int k, String featureCol) {
+            this.topN = k;
+            this.broadcastKey = broadcastKey;
+            this.featureCol = featureCol;
+        }
+
+        @Override
+        public Row map(Row row) {
+            if (firstElement) {
+                comparator = Comparator.comparingDouble(o -> -o.f0);
+                loadModel(getRuntimeContext().getBroadcastVariable(broadcastKey));
+                firstElement = false;
+            }
+            DenseVector vector = (DenseVector) row.getField(featureCol);
+            Tuple2<List<Double>, List<Double>> t2 = findNeighbor(vector, topN);
+            return Row.join(row, Row.of(getResult(t2)));
+        }
+
+        /**
+         * Finds the nearest topN neighbors from whole nodes.
+         *
+         * @param input Input vector.
+         * @param topN Top N.
+         * @return Neighbors.
+         */
+        private Tuple2<List<Double>, List<Double>> findNeighbor(DenseVector input, Integer topN) {
+            PriorityQueue<Tuple2<Double, Double>> priorityQueue = new PriorityQueue<>(comparator);

Review comment:
       Since we are OK with instantiating `priorityQueue` every time `map` is invoked, could we also instantiate `comparator` here, for better code simplicity and consistency?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

