lindong28 commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r767490529



##########
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/DenseMatrix.java
##########
@@ -0,0 +1,69 @@
+package org.apache.flink.ml.linalg;
+
+import org.apache.flink.api.common.typeinfo.TypeInfo;
+import org.apache.flink.ml.linalg.typeinfo.DenseMatrixTypeInfoFactory;
+
+import org.apache.flink.shaded.curator4.com.google.common.base.Preconditions;

Review comment:
       Can we replace this with `org.apache.flink.util.Preconditions`?

##########
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##########
@@ -0,0 +1,224 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+
+/** Tests {@link Knn} and {@link KnnModel}. */
+public class KnnTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainData;
+    private Table validData;

Review comment:
       nits: it is not very clear what `validData` is. Would `predictData` be a
bit more intuitive here?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java
##########
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseMatrixSerializer;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.OutputStream;
+
+/**
+ * Model data of {@link KnnModel}.
+ *
+ * <p>This class also provides methods to convert model data from Table to 
Datastream, and classes
+ * to save/load model data.
+ */
+public class KnnModelData {

Review comment:
       @gaoyunhaii is it OK to use `KnnModelData` as the DataStream element
type when `KnnModelData` is not serializable?
   
   If we use `KnnModelData` as DataStream element type, would Flink 
automatically use `DenseMatrixSerializer` and `DenseVectorSerializer` to 
serialize/de-serialize its member variables?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,223 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.classification.knn.KnnModelData.ModelDataDecoder;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+
+/** A Model which classifies data using the model data computed by {@link 
Knn}. */
+public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> {
+    protected Map<Param<?>, Object> params = new HashMap<>();
+    private Table modelDataTable;
+
+    public KnnModel() {
+        ParamUtils.initializeMapWithDefaultValues(params, this);
+    }
+
+    @Override
+    public KnnModel setModelData(Table... modelData) {
+        this.modelDataTable = modelData[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        DataStream<Row> data = tEnv.toDataStream(inputs[0]);
+        DataStream<KnnModelData> knnModel = 
KnnModelData.getModelDataStream(modelDataTable);
+        final String broadcastModelKey = "broadcastModelKey";
+
+        RowTypeInfo inputTypeInfo = 
TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(), 
BasicTypeInfo.DOUBLE_TYPE_INFO),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), 
getPredictionCol()));
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(data),
+                        Collections.singletonMap(broadcastModelKey, knnModel),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(
+                                    new PredictLabelFunction(
+                                            broadcastModelKey, getK(), 
getFeaturesCol()),
+                                    outputTypeInfo);
+                        });
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return this.params;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                KnnModelData.getModelDataStream(modelDataTable),
+                path,
+                new KnnModelData.ModelDataEncoder(),
+                0);
+    }
+
+    /**
+     * Loads model data from path.
+     *
+     * @param env Stream execution environment.
+     * @param path Model path.
+     * @return Knn model.
+     */
+    public static KnnModel load(StreamExecutionEnvironment env, String path) 
throws IOException {
+        KnnModel model = ReadWriteUtils.loadStageParam(path);
+        Table modelDataTable = ReadWriteUtils.loadModelData(env, path, new 
ModelDataDecoder(), 0);
+        return model.setModelData(modelDataTable);
+    }
+
+    /** This operator loads model data and predicts result. */
+    private static class PredictLabelFunction extends RichMapFunction<Row, 
Row> {
+        private final String featureCol;
+        private KnnModelData knnModelData;
+        private final int k;
+        private final String broadcastKey;
+
+        public PredictLabelFunction(String broadcastKey, int k, String 
featureCol) {
+            this.k = k;
+            this.broadcastKey = broadcastKey;
+            this.featureCol = featureCol;
+        }
+
+        @Override
+        public Row map(Row row) {
+            if (knnModelData == null) {
+                knnModelData =
+                        (KnnModelData)
+                                
getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
+            }
+            DenseVector vector = (DenseVector) row.getField(featureCol);
+            if (vector == null) {

Review comment:
       Hmm... why do we allow the user to provide an input row without the
`featureCol`? Would it be better to throw an exception here?
   
   And could we rename this variable to `feature` to be consistent with
`featureCol`?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,222 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.classification.knn.KnnModelData.ModelDataDecoder;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+
+/** A Model which classifies data using the model data computed by {@link 
Knn}. */
+public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> {
+    protected Map<Param<?>, Object> params = new HashMap<>();
+    private Table modelDataTable;
+
+    public KnnModel() {
+        ParamUtils.initializeMapWithDefaultValues(params, this);
+    }
+
+    @Override
+    public KnnModel setModelData(Table... modelData) {
+        this.modelDataTable = modelData[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        DataStream<Row> data = tEnv.toDataStream(inputs[0]);
+        DataStream<KnnModelData> knnModel = 
KnnModelData.getModelDataStream(modelDataTable);
+        final String broadcastModelKey = "broadcastModelKey";
+        RowTypeInfo inputTypeInfo = 
TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(), 
BasicTypeInfo.DOUBLE_TYPE_INFO),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), 
getPredictionCol()));
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(data),
+                        Collections.singletonMap(broadcastModelKey, knnModel),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(
+                                    new PredictLabelFunction(
+                                            broadcastModelKey, getK(), 
getFeaturesCol()),
+                                    outputTypeInfo);
+                        });
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return this.params;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                KnnModelData.getModelDataStream(modelDataTable),
+                path,
+                new KnnModelData.ModelDataEncoder(),
+                0);
+    }
+
+    /**
+     * Loads model data from path.
+     *
+     * @param env Stream execution environment.
+     * @param path Model path.
+     * @return Knn model.
+     */
+    public static KnnModel load(StreamExecutionEnvironment env, String path) 
throws IOException {
+        KnnModel model = ReadWriteUtils.loadStageParam(path);
+        Table modelDataTable = ReadWriteUtils.loadModelData(env, path, new 
ModelDataDecoder(), 0);
+        return model.setModelData(modelDataTable);
+    }
+
+    /** This operator loads model data and predicts result. */
+    private static class PredictLabelFunction extends RichMapFunction<Row, 
Row> {
+        private final String featureCol;
+        private KnnModelData knnModelData;
+        private final int k;
+        private final String broadcastKey;
+
+        public PredictLabelFunction(String broadcastKey, int k, String 
featureCol) {
+            this.k = k;
+            this.broadcastKey = broadcastKey;
+            this.featureCol = featureCol;
+        }
+
+        @Override
+        public Row map(Row row) {
+            if (knnModelData == null) {
+                knnModelData =
+                        (KnnModelData)
+                                
getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
+            }
+            DenseVector vector = (DenseVector) row.getField(featureCol);
+            if (vector == null) {
+                return Row.join(row, new Row(1));
+            }
+            Tuple2<List<Double>, List<Double>> tuple2 = findNeighbor(vector);

Review comment:
       It seems that only the first element of this tuple2 is used. Could we 
simplify `findNeighbor(...)` to return just a `List<Double>`?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,223 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.classification.knn.KnnModelData.ModelDataDecoder;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+
+/** A Model which classifies data using the model data computed by {@link 
Knn}. */
+public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> {
+    protected Map<Param<?>, Object> params = new HashMap<>();
+    private Table modelDataTable;
+
+    public KnnModel() {
+        ParamUtils.initializeMapWithDefaultValues(params, this);
+    }
+
+    @Override
+    public KnnModel setModelData(Table... modelData) {
+        this.modelDataTable = modelData[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        DataStream<Row> data = tEnv.toDataStream(inputs[0]);
+        DataStream<KnnModelData> knnModel = 
KnnModelData.getModelDataStream(modelDataTable);
+        final String broadcastModelKey = "broadcastModelKey";
+
+        RowTypeInfo inputTypeInfo = 
TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(), 
BasicTypeInfo.DOUBLE_TYPE_INFO),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), 
getPredictionCol()));
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(data),
+                        Collections.singletonMap(broadcastModelKey, knnModel),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(
+                                    new PredictLabelFunction(
+                                            broadcastModelKey, getK(), 
getFeaturesCol()),
+                                    outputTypeInfo);
+                        });
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return this.params;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                KnnModelData.getModelDataStream(modelDataTable),
+                path,
+                new KnnModelData.ModelDataEncoder(),
+                0);
+    }
+
+    /**
+     * Loads model data from path.
+     *
+     * @param env Stream execution environment.
+     * @param path Model path.
+     * @return Knn model.
+     */
+    public static KnnModel load(StreamExecutionEnvironment env, String path) 
throws IOException {
+        KnnModel model = ReadWriteUtils.loadStageParam(path);
+        Table modelDataTable = ReadWriteUtils.loadModelData(env, path, new 
ModelDataDecoder(), 0);
+        return model.setModelData(modelDataTable);
+    }
+
+    /** This operator loads model data and predicts result. */
+    private static class PredictLabelFunction extends RichMapFunction<Row, 
Row> {
+        private final String featureCol;
+        private KnnModelData knnModelData;
+        private final int k;
+        private final String broadcastKey;
+
+        public PredictLabelFunction(String broadcastKey, int k, String 
featureCol) {
+            this.k = k;
+            this.broadcastKey = broadcastKey;
+            this.featureCol = featureCol;
+        }
+
+        @Override
+        public Row map(Row row) {
+            if (knnModelData == null) {
+                knnModelData =
+                        (KnnModelData)
+                                
getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
+            }
+            DenseVector vector = (DenseVector) row.getField(featureCol);
+            if (vector == null) {
+                return Row.join(row, new Row(1));
+            }
+            Tuple2<List<Double>, List<Double>> tuple2 = findNeighbor(vector);
+            double percent = 1.0 / tuple2.f0.size();
+            Map<Double, Double> detail = new HashMap<>(0);
+            for (Double obj : tuple2.f0) {

Review comment:
       nits: could we rename `Double obj` to `Double label` to make the code a 
bit more self-explanatory?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
##########
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the KNN algorithm.
+ *
+ * <p>See: https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm.
+ */
+public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> {
+
+    protected Map<Param<?>, Object> params = new HashMap<>();
+
+    public Knn() {
+        ParamUtils.initializeMapWithDefaultValues(params, this);
+    }
+
+    @Override
+    public KnnModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        /* Tuple2 : <sampleVector, label> */
+        DataStream<Tuple2<DenseVector, Double>> inputData =
+                tEnv.toDataStream(inputs[0])
+                        .map(
+                                new MapFunction<Row, Tuple2<DenseVector, 
Double>>() {
+                                    @Override
+                                    public Tuple2<DenseVector, Double> map(Row 
value) {
+                                        Double label = (Double) 
value.getField(getLabelCol());
+                                        DenseVector feature =
+                                                (DenseVector) 
value.getField(getFeaturesCol());
+                                        return Tuple2.of(feature, label);
+                                    }
+                                });
+        DataStream<KnnModelData> distributedModelData = 
prepareModelData(inputData);
+        DataStream<KnnModelData> modelData = 
mergeModelData(distributedModelData);
+        KnnModel model = new 
KnnModel().setModelData(tEnv.fromDataStream(modelData));
+        ReadWriteUtils.updateExistingParams(model, getParamMap());
+        return model;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return this.params;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static Knn load(StreamExecutionEnvironment env, String path) throws 
IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    /**
+     * Prepares distributed knn model data. Constructs the sample matrix and 
computes norm of
+     * vectors.
+     *
+     * @param inputData Input vector data with label.
+     * @return Distributed knn model.
+     */
+    private static DataStream<KnnModelData> prepareModelData(
+            DataStream<Tuple2<DenseVector, Double>> inputData) {
+        return DataStreamUtils.mapPartition(
+                inputData,
+                new RichMapPartitionFunction<Tuple2<DenseVector, Double>, 
KnnModelData>() {
+                    @Override
+                    public void mapPartition(
+                            Iterable<Tuple2<DenseVector, Double>> values,
+                            Collector<KnnModelData> out) {
+                        List<Tuple2<DenseVector, Double>> dataPoints = new 
ArrayList<>(0);
+                        for (Tuple2<DenseVector, Double> tuple2 : values) {
+                            dataPoints.add(tuple2);
+                        }
+                        int featureDim = dataPoints.get(0).f0.size();
+                        DenseMatrix packedFeatures = new 
DenseMatrix(featureDim, dataPoints.size());
+                        DenseVector labels = new 
DenseVector(dataPoints.size());
+                        for (int i = 0; i < dataPoints.size(); ++i) {
+                            Tuple2<DenseVector, Double> tuple2 = 
dataPoints.get(i);
+                            labels.values[i] = tuple2.f1;
+                            double[] vectorData = tuple2.f0.values;
+                            double[] matrixData = packedFeatures.values;
+                            System.arraycopy(vectorData, 0, matrixData, i * 
featureDim, featureDim);
+                        }
+                        DenseVector featureNorms = computeNorm(packedFeatures);
+                        if (dataPoints.size() > 0) {
+                            out.collect(new KnnModelData(packedFeatures, 
featureNorms, labels));
+                        }
+                    }
+                },
+                TypeInformation.of(KnnModelData.class));
+    }
+
+    /**
+     * Merges knn model data.
+     *
+     * @param distributedModelData Distributed knn model data.
+     * @return Knn model.
+     */
+    private static DataStream<KnnModelData> mergeModelData(
+            DataStream<KnnModelData> distributedModelData) {
+        distributedModelData.getTransformation().setParallelism(1);
+        return DataStreamUtils.mapPartition(
+                distributedModelData,
+                new RichMapPartitionFunction<KnnModelData, KnnModelData>() {
+                    @Override
+                    public void mapPartition(
+                            Iterable<KnnModelData> values, 
Collector<KnnModelData> out) {
+                        List<KnnModelData> bufferKnnModelData = new 
ArrayList<>(1);

Review comment:
       nits: would it be slightly better to use `bufferedKnnModelData`?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,223 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.classification.knn.KnnModelData.ModelDataDecoder;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+
+/** A Model which classifies data using the model data computed by {@link 
Knn}. */
+public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> {
+    protected Map<Param<?>, Object> params = new HashMap<>();
+    private Table modelDataTable;
+
+    public KnnModel() {
+        ParamUtils.initializeMapWithDefaultValues(params, this);
+    }
+
+    @Override
+    public KnnModel setModelData(Table... modelData) {
+        this.modelDataTable = modelData[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        DataStream<Row> data = tEnv.toDataStream(inputs[0]);
+        DataStream<KnnModelData> knnModel = 
KnnModelData.getModelDataStream(modelDataTable);
+        final String broadcastModelKey = "broadcastModelKey";
+
+        RowTypeInfo inputTypeInfo = 
TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(), 
BasicTypeInfo.DOUBLE_TYPE_INFO),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), 
getPredictionCol()));
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(data),
+                        Collections.singletonMap(broadcastModelKey, knnModel),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(
+                                    new PredictLabelFunction(
+                                            broadcastModelKey, getK(), 
getFeaturesCol()),
+                                    outputTypeInfo);
+                        });
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return this.params;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                KnnModelData.getModelDataStream(modelDataTable),
+                path,
+                new KnnModelData.ModelDataEncoder(),
+                0);
+    }
+
+    /**
+     * Loads model data from path.
+     *
+     * @param env Stream execution environment.
+     * @param path Model path.
+     * @return Knn model.
+     */
+    public static KnnModel load(StreamExecutionEnvironment env, String path) 
throws IOException {
+        KnnModel model = ReadWriteUtils.loadStageParam(path);
+        Table modelDataTable = ReadWriteUtils.loadModelData(env, path, new 
ModelDataDecoder(), 0);
+        return model.setModelData(modelDataTable);
+    }
+
+    /** This operator loads model data and predicts result. */
+    private static class PredictLabelFunction extends RichMapFunction<Row, 
Row> {
+        private final String featureCol;
+        private KnnModelData knnModelData;
+        private final int k;
+        private final String broadcastKey;
+
+        public PredictLabelFunction(String broadcastKey, int k, String 
featureCol) {
+            this.k = k;
+            this.broadcastKey = broadcastKey;
+            this.featureCol = featureCol;
+        }
+
+        @Override
+        public Row map(Row row) {
+            if (knnModelData == null) {
+                knnModelData =
+                        (KnnModelData)
+                                
getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
+            }
+            DenseVector vector = (DenseVector) row.getField(featureCol);
+            if (vector == null) {
+                return Row.join(row, new Row(1));
+            }
+            Tuple2<List<Double>, List<Double>> tuple2 = findNeighbor(vector);
+            double percent = 1.0 / tuple2.f0.size();
+            Map<Double, Double> detail = new HashMap<>(0);
+            for (Double obj : tuple2.f0) {
+                detail.merge(obj, percent, Double::sum);
+            }
+            double max = 0.0;
+            double prediction = 0.0;
+            for (Map.Entry<Double, Double> entry : detail.entrySet()) {
+                if (entry.getValue() > max) {
+                    max = entry.getValue();
+                    prediction = entry.getKey();
+                }
+            }
+            return Row.join(row, Row.of(prediction));
+        }
+
+        /** Finds the nearest k neighbors from whole vectors in matrix. */
+        private Tuple2<List<Double>, List<Double>> findNeighbor(DenseVector 
input) {
+            PriorityQueue<Tuple2<Double, Double>> priorityQueue =
+                    new PriorityQueue<>(Comparator.comparingDouble(o -> 
-o.f0));
+            double d = 0.0;
+            for (int i = 0; i < input.size(); ++i) {
+                d += input.values[i] * input.values[i];
+            }
+            Tuple2<DenseVector, Double> sample = Tuple2.of(input, d);
+            DenseMatrix packedFeatures = knnModelData.packedFeatures;
+            double[] labelValues = knnModelData.labels.values;
+            double[] normValues = knnModelData.featureNorms.values;
+            DenseVector distanceVector = new 
DenseVector(packedFeatures.numCols());
+            BLAS.gemv(-2.0, packedFeatures, true, sample.f0, 0.0, 
distanceVector);
+            for (int i = 0; i < distanceVector.values.length; i++) {
+                distanceVector.values[i] =
+                        Math.sqrt(Math.abs(distanceVector.values[i] + 
sample.f1 + normValues[i]));
+            }
+            List<Tuple2<Double, Double>> distances = new 
ArrayList<>(labelValues.length);
+            for (int i = 0; i < labelValues.length; i++) {
+                Tuple2<Double, Double> tuple2 = 
Tuple2.of(distanceVector.values[i], labelValues[i]);
+                distances.add(tuple2);
+            }
+            Tuple2<Double, Double> head = null;
+            for (Tuple2<Double, Double> currentValue : distances) {

Review comment:
       nit: could we rename `currentValue` to something more meaningful and 
consistent with `distance`, e.g. `distanceAndLabel`?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,223 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.classification.knn.KnnModelData.ModelDataDecoder;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+
+/** A Model which classifies data using the model data computed by {@link 
Knn}. */
+public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> {
+    protected Map<Param<?>, Object> params = new HashMap<>();
+    private Table modelDataTable;
+
+    public KnnModel() {
+        ParamUtils.initializeMapWithDefaultValues(params, this);
+    }
+
+    @Override
+    public KnnModel setModelData(Table... modelData) {
+        this.modelDataTable = modelData[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        DataStream<Row> data = tEnv.toDataStream(inputs[0]);
+        DataStream<KnnModelData> knnModel = 
KnnModelData.getModelDataStream(modelDataTable);
+        final String broadcastModelKey = "broadcastModelKey";
+
+        RowTypeInfo inputTypeInfo = 
TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(), 
BasicTypeInfo.DOUBLE_TYPE_INFO),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), 
getPredictionCol()));
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(data),
+                        Collections.singletonMap(broadcastModelKey, knnModel),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(
+                                    new PredictLabelFunction(
+                                            broadcastModelKey, getK(), 
getFeaturesCol()),
+                                    outputTypeInfo);
+                        });
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return this.params;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                KnnModelData.getModelDataStream(modelDataTable),
+                path,
+                new KnnModelData.ModelDataEncoder(),
+                0);
+    }
+
+    /**
+     * Loads model data from path.
+     *
+     * @param env Stream execution environment.
+     * @param path Model path.
+     * @return Knn model.
+     */
+    public static KnnModel load(StreamExecutionEnvironment env, String path) 
throws IOException {
+        KnnModel model = ReadWriteUtils.loadStageParam(path);
+        Table modelDataTable = ReadWriteUtils.loadModelData(env, path, new 
ModelDataDecoder(), 0);
+        return model.setModelData(modelDataTable);
+    }
+
+    /** This operator loads model data and predicts result. */
+    private static class PredictLabelFunction extends RichMapFunction<Row, 
Row> {
+        private final String featureCol;
+        private KnnModelData knnModelData;
+        private final int k;
+        private final String broadcastKey;
+
+        public PredictLabelFunction(String broadcastKey, int k, String 
featureCol) {
+            this.k = k;
+            this.broadcastKey = broadcastKey;
+            this.featureCol = featureCol;
+        }
+
+        @Override
+        public Row map(Row row) {
+            if (knnModelData == null) {
+                knnModelData =
+                        (KnnModelData)
+                                
getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
+            }
+            DenseVector vector = (DenseVector) row.getField(featureCol);
+            if (vector == null) {
+                return Row.join(row, new Row(1));
+            }
+            Tuple2<List<Double>, List<Double>> tuple2 = findNeighbor(vector);
+            double percent = 1.0 / tuple2.f0.size();
+            Map<Double, Double> detail = new HashMap<>(0);
+            for (Double obj : tuple2.f0) {
+                detail.merge(obj, percent, Double::sum);
+            }
+            double max = 0.0;
+            double prediction = 0.0;
+            for (Map.Entry<Double, Double> entry : detail.entrySet()) {
+                if (entry.getValue() > max) {
+                    max = entry.getValue();
+                    prediction = entry.getKey();
+                }
+            }
+            return Row.join(row, Row.of(prediction));
+        }
+
+        /** Finds the nearest k neighbors from whole vectors in matrix. */
+        private Tuple2<List<Double>, List<Double>> findNeighbor(DenseVector 
input) {
+            PriorityQueue<Tuple2<Double, Double>> priorityQueue =
+                    new PriorityQueue<>(Comparator.comparingDouble(o -> 
-o.f0));
+            double d = 0.0;
+            for (int i = 0; i < input.size(); ++i) {
+                d += input.values[i] * input.values[i];
+            }
+            Tuple2<DenseVector, Double> sample = Tuple2.of(input, d);
+            DenseMatrix packedFeatures = knnModelData.packedFeatures;
+            double[] labelValues = knnModelData.labels.values;
+            double[] normValues = knnModelData.featureNorms.values;
+            DenseVector distanceVector = new 
DenseVector(packedFeatures.numCols());
+            BLAS.gemv(-2.0, packedFeatures, true, sample.f0, 0.0, 
distanceVector);
+            for (int i = 0; i < distanceVector.values.length; i++) {
+                distanceVector.values[i] =
+                        Math.sqrt(Math.abs(distanceVector.values[i] + 
sample.f1 + normValues[i]));
+            }
+            List<Tuple2<Double, Double>> distances = new 
ArrayList<>(labelValues.length);
+            for (int i = 0; i < labelValues.length; i++) {
+                Tuple2<Double, Double> tuple2 = 
Tuple2.of(distanceVector.values[i], labelValues[i]);
+                distances.add(tuple2);
+            }
+            Tuple2<Double, Double> head = null;

Review comment:
       nit: instead of maintaining the `head` variable, would it be simpler (with 
the same efficiency) to call `priorityQueue.peek()` whenever it is needed?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,223 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.classification.knn.KnnModelData.ModelDataDecoder;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+
+/** A Model which classifies data using the model data computed by {@link 
Knn}. */
+public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> {
+    protected Map<Param<?>, Object> params = new HashMap<>();
+    private Table modelDataTable;
+
+    public KnnModel() {
+        ParamUtils.initializeMapWithDefaultValues(params, this);
+    }
+
+    @Override
+    public KnnModel setModelData(Table... modelData) {
+        this.modelDataTable = modelData[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        DataStream<Row> data = tEnv.toDataStream(inputs[0]);
+        DataStream<KnnModelData> knnModel = 
KnnModelData.getModelDataStream(modelDataTable);
+        final String broadcastModelKey = "broadcastModelKey";
+
+        RowTypeInfo inputTypeInfo = 
TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(), 
BasicTypeInfo.DOUBLE_TYPE_INFO),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), 
getPredictionCol()));
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(data),
+                        Collections.singletonMap(broadcastModelKey, knnModel),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(
+                                    new PredictLabelFunction(
+                                            broadcastModelKey, getK(), 
getFeaturesCol()),
+                                    outputTypeInfo);
+                        });
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return this.params;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                KnnModelData.getModelDataStream(modelDataTable),
+                path,
+                new KnnModelData.ModelDataEncoder(),
+                0);
+    }
+
+    /**
+     * Loads model data from path.
+     *
+     * @param env Stream execution environment.
+     * @param path Model path.
+     * @return Knn model.
+     */
+    public static KnnModel load(StreamExecutionEnvironment env, String path) 
throws IOException {
+        KnnModel model = ReadWriteUtils.loadStageParam(path);
+        Table modelDataTable = ReadWriteUtils.loadModelData(env, path, new 
ModelDataDecoder(), 0);
+        return model.setModelData(modelDataTable);
+    }
+
+    /** This operator loads model data and predicts result. */
+    private static class PredictLabelFunction extends RichMapFunction<Row, 
Row> {
+        private final String featureCol;
+        private KnnModelData knnModelData;
+        private final int k;
+        private final String broadcastKey;
+
+        public PredictLabelFunction(String broadcastKey, int k, String 
featureCol) {
+            this.k = k;
+            this.broadcastKey = broadcastKey;
+            this.featureCol = featureCol;
+        }
+
+        @Override
+        public Row map(Row row) {
+            if (knnModelData == null) {
+                knnModelData =
+                        (KnnModelData)
+                                
getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
+            }
+            DenseVector vector = (DenseVector) row.getField(featureCol);
+            if (vector == null) {
+                return Row.join(row, new Row(1));
+            }
+            Tuple2<List<Double>, List<Double>> tuple2 = findNeighbor(vector);
+            double percent = 1.0 / tuple2.f0.size();
+            Map<Double, Double> detail = new HashMap<>(0);
+            for (Double obj : tuple2.f0) {
+                detail.merge(obj, percent, Double::sum);
+            }
+            double max = 0.0;
+            double prediction = 0.0;
+            for (Map.Entry<Double, Double> entry : detail.entrySet()) {
+                if (entry.getValue() > max) {
+                    max = entry.getValue();
+                    prediction = entry.getKey();
+                }
+            }
+            return Row.join(row, Row.of(prediction));
+        }
+
+        /** Finds the nearest k neighbors from whole vectors in matrix. */
+        private Tuple2<List<Double>, List<Double>> findNeighbor(DenseVector 
input) {
+            PriorityQueue<Tuple2<Double, Double>> priorityQueue =
+                    new PriorityQueue<>(Comparator.comparingDouble(o -> 
-o.f0));
+            double d = 0.0;
+            for (int i = 0; i < input.size(); ++i) {
+                d += input.values[i] * input.values[i];
+            }
+            Tuple2<DenseVector, Double> sample = Tuple2.of(input, d);
+            DenseMatrix packedFeatures = knnModelData.packedFeatures;
+            double[] labelValues = knnModelData.labels.values;
+            double[] normValues = knnModelData.featureNorms.values;
+            DenseVector distanceVector = new 
DenseVector(packedFeatures.numCols());
+            BLAS.gemv(-2.0, packedFeatures, true, sample.f0, 0.0, 
distanceVector);
+            for (int i = 0; i < distanceVector.values.length; i++) {
+                distanceVector.values[i] =
+                        Math.sqrt(Math.abs(distanceVector.values[i] + 
sample.f1 + normValues[i]));
+            }
+            List<Tuple2<Double, Double>> distances = new 
ArrayList<>(labelValues.length);

Review comment:
       nit: could we give `distances` a more meaningful name to improve 
readability, e.g. `distanceAndLabels`?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,223 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.classification.knn.KnnModelData.ModelDataDecoder;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+
+/** A Model which classifies data using the model data computed by {@link 
Knn}. */
+public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> {
+    protected Map<Param<?>, Object> params = new HashMap<>();
+    private Table modelDataTable;
+
+    public KnnModel() {
+        ParamUtils.initializeMapWithDefaultValues(params, this);
+    }
+
+    @Override
+    public KnnModel setModelData(Table... modelData) {
+        this.modelDataTable = modelData[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        DataStream<Row> data = tEnv.toDataStream(inputs[0]);
+        DataStream<KnnModelData> knnModel = 
KnnModelData.getModelDataStream(modelDataTable);
+        final String broadcastModelKey = "broadcastModelKey";
+
+        RowTypeInfo inputTypeInfo = 
TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(), 
BasicTypeInfo.DOUBLE_TYPE_INFO),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), 
getPredictionCol()));
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(data),
+                        Collections.singletonMap(broadcastModelKey, knnModel),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(
+                                    new PredictLabelFunction(
+                                            broadcastModelKey, getK(), 
getFeaturesCol()),
+                                    outputTypeInfo);
+                        });
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return this.params;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                KnnModelData.getModelDataStream(modelDataTable),
+                path,
+                new KnnModelData.ModelDataEncoder(),
+                0);
+    }
+
+    /**
+     * Loads model data from path.
+     *
+     * @param env Stream execution environment.
+     * @param path Model path.
+     * @return Knn model.
+     */
+    public static KnnModel load(StreamExecutionEnvironment env, String path) 
throws IOException {
+        KnnModel model = ReadWriteUtils.loadStageParam(path);
+        Table modelDataTable = ReadWriteUtils.loadModelData(env, path, new 
ModelDataDecoder(), 0);
+        return model.setModelData(modelDataTable);
+    }
+
+    /** This operator loads model data and predicts result. */
+    private static class PredictLabelFunction extends RichMapFunction<Row, 
Row> {
+        private final String featureCol;
+        private KnnModelData knnModelData;
+        private final int k;
+        private final String broadcastKey;
+
+        public PredictLabelFunction(String broadcastKey, int k, String 
featureCol) {
+            this.k = k;
+            this.broadcastKey = broadcastKey;
+            this.featureCol = featureCol;
+        }
+
+        @Override
+        public Row map(Row row) {
+            if (knnModelData == null) {
+                knnModelData =
+                        (KnnModelData)
+                                
getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
+            }
+            DenseVector vector = (DenseVector) row.getField(featureCol);
+            if (vector == null) {
+                return Row.join(row, new Row(1));
+            }
+            Tuple2<List<Double>, List<Double>> tuple2 = findNeighbor(vector);
+            double percent = 1.0 / tuple2.f0.size();
+            Map<Double, Double> detail = new HashMap<>(0);
+            for (Double obj : tuple2.f0) {
+                detail.merge(obj, percent, Double::sum);
+            }
+            double max = 0.0;
+            double prediction = 0.0;
+            for (Map.Entry<Double, Double> entry : detail.entrySet()) {
+                if (entry.getValue() > max) {
+                    max = entry.getValue();
+                    prediction = entry.getKey();
+                }
+            }
+            return Row.join(row, Row.of(prediction));
+        }
+
+        /** Finds the nearest k neighbors from whole vectors in matrix. */
+        private Tuple2<List<Double>, List<Double>> findNeighbor(DenseVector 
input) {
+            PriorityQueue<Tuple2<Double, Double>> priorityQueue =
+                    new PriorityQueue<>(Comparator.comparingDouble(o -> 
-o.f0));
+            double d = 0.0;
+            for (int i = 0; i < input.size(); ++i) {
+                d += input.values[i] * input.values[i];
+            }
+            Tuple2<DenseVector, Double> sample = Tuple2.of(input, d);
+            DenseMatrix packedFeatures = knnModelData.packedFeatures;
+            double[] labelValues = knnModelData.labels.values;
+            double[] normValues = knnModelData.featureNorms.values;
+            DenseVector distanceVector = new 
DenseVector(packedFeatures.numCols());
+            BLAS.gemv(-2.0, packedFeatures, true, sample.f0, 0.0, 
distanceVector);
+            for (int i = 0; i < distanceVector.values.length; i++) {
+                distanceVector.values[i] =
+                        Math.sqrt(Math.abs(distanceVector.values[i] + 
sample.f1 + normValues[i]));
+            }
+            List<Tuple2<Double, Double>> distances = new 
ArrayList<>(labelValues.length);
+            for (int i = 0; i < labelValues.length; i++) {
+                Tuple2<Double, Double> tuple2 = 
Tuple2.of(distanceVector.values[i], labelValues[i]);
+                distances.add(tuple2);
+            }
+            Tuple2<Double, Double> head = null;
+            for (Tuple2<Double, Double> currentValue : distances) {
+                if (priorityQueue.size() < k) {
+                    priorityQueue.add(Tuple2.of(currentValue.f0, 
currentValue.f1));
+                    head = priorityQueue.peek();
+                } else {
+                    if (priorityQueue.comparator().compare(head, currentValue) 
< 0) {
+                        Tuple2<Double, Double> peek = priorityQueue.poll();
+                        if (peek != null) {
+                            peek.f0 = currentValue.f0;
+                            peek.f1 = currentValue.f1;
+                            priorityQueue.add(peek);
+                            head = priorityQueue.peek();
+                        }
+                    }
+                }
+            }
+            List<Double> items = new ArrayList<>();
+            List<Double> metrics = new ArrayList<>();

Review comment:
       nits: could we make these variable names more consistent with the above 
code and also more meaningful? Maybe something like `labels` and `distances`.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,223 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.classification.knn.KnnModelData.ModelDataDecoder;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+
+/** A Model which classifies data using the model data computed by {@link 
Knn}. */
+public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> {
+    protected Map<Param<?>, Object> params = new HashMap<>();
+    private Table modelDataTable;
+
+    public KnnModel() {
+        ParamUtils.initializeMapWithDefaultValues(params, this);
+    }
+
+    @Override
+    public KnnModel setModelData(Table... modelData) {
+        this.modelDataTable = modelData[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        DataStream<Row> data = tEnv.toDataStream(inputs[0]);
+        DataStream<KnnModelData> knnModel = 
KnnModelData.getModelDataStream(modelDataTable);
+        final String broadcastModelKey = "broadcastModelKey";
+
+        RowTypeInfo inputTypeInfo = 
TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(), 
BasicTypeInfo.DOUBLE_TYPE_INFO),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), 
getPredictionCol()));
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(data),
+                        Collections.singletonMap(broadcastModelKey, knnModel),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(
+                                    new PredictLabelFunction(
+                                            broadcastModelKey, getK(), 
getFeaturesCol()),
+                                    outputTypeInfo);
+                        });
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return this.params;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                KnnModelData.getModelDataStream(modelDataTable),
+                path,
+                new KnnModelData.ModelDataEncoder(),
+                0);
+    }
+
+    /**
+     * Loads model data from path.
+     *
+     * @param env Stream execution environment.
+     * @param path Model path.
+     * @return Knn model.
+     */
+    public static KnnModel load(StreamExecutionEnvironment env, String path) 
throws IOException {
+        KnnModel model = ReadWriteUtils.loadStageParam(path);
+        Table modelDataTable = ReadWriteUtils.loadModelData(env, path, new 
ModelDataDecoder(), 0);
+        return model.setModelData(modelDataTable);
+    }
+
+    /** This operator loads model data and predicts result. */
+    private static class PredictLabelFunction extends RichMapFunction<Row, 
Row> {
+        private final String featureCol;
+        private KnnModelData knnModelData;
+        private final int k;
+        private final String broadcastKey;
+
+        public PredictLabelFunction(String broadcastKey, int k, String 
featureCol) {
+            this.k = k;
+            this.broadcastKey = broadcastKey;
+            this.featureCol = featureCol;
+        }
+
+        @Override
+        public Row map(Row row) {
+            if (knnModelData == null) {
+                knnModelData =
+                        (KnnModelData)
+                                
getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
+            }
+            DenseVector vector = (DenseVector) row.getField(featureCol);
+            if (vector == null) {
+                return Row.join(row, new Row(1));
+            }
+            Tuple2<List<Double>, List<Double>> tuple2 = findNeighbor(vector);
+            double percent = 1.0 / tuple2.f0.size();
+            Map<Double, Double> detail = new HashMap<>(0);
+            for (Double obj : tuple2.f0) {
+                detail.merge(obj, percent, Double::sum);
+            }
+            double max = 0.0;
+            double prediction = 0.0;
+            for (Map.Entry<Double, Double> entry : detail.entrySet()) {
+                if (entry.getValue() > max) {
+                    max = entry.getValue();
+                    prediction = entry.getKey();
+                }
+            }
+            return Row.join(row, Row.of(prediction));
+        }
+
+        /** Finds the nearest k neighbors from whole vectors in matrix. */
+        private Tuple2<List<Double>, List<Double>> findNeighbor(DenseVector 
input) {
+            PriorityQueue<Tuple2<Double, Double>> priorityQueue =

Review comment:
       nits: would it be slightly better to rename this as `nearestKNeighbors`?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,223 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.classification.knn.KnnModelData.ModelDataDecoder;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+
+/** A Model which classifies data using the model data computed by {@link 
Knn}. */
+public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> {
+    protected Map<Param<?>, Object> params = new HashMap<>();
+    private Table modelDataTable;
+
+    public KnnModel() {
+        ParamUtils.initializeMapWithDefaultValues(params, this);
+    }
+
+    @Override
+    public KnnModel setModelData(Table... modelData) {
+        this.modelDataTable = modelData[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        DataStream<Row> data = tEnv.toDataStream(inputs[0]);
+        DataStream<KnnModelData> knnModel = 
KnnModelData.getModelDataStream(modelDataTable);
+        final String broadcastModelKey = "broadcastModelKey";
+
+        RowTypeInfo inputTypeInfo = 
TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(), 
BasicTypeInfo.DOUBLE_TYPE_INFO),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), 
getPredictionCol()));
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(data),
+                        Collections.singletonMap(broadcastModelKey, knnModel),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(
+                                    new PredictLabelFunction(
+                                            broadcastModelKey, getK(), 
getFeaturesCol()),
+                                    outputTypeInfo);
+                        });
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return this.params;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                KnnModelData.getModelDataStream(modelDataTable),
+                path,
+                new KnnModelData.ModelDataEncoder(),
+                0);
+    }
+
+    /**
+     * Loads model data from path.
+     *
+     * @param env Stream execution environment.
+     * @param path Model path.
+     * @return Knn model.
+     */
+    public static KnnModel load(StreamExecutionEnvironment env, String path) 
throws IOException {
+        KnnModel model = ReadWriteUtils.loadStageParam(path);
+        Table modelDataTable = ReadWriteUtils.loadModelData(env, path, new 
ModelDataDecoder(), 0);
+        return model.setModelData(modelDataTable);
+    }
+
+    /** This operator loads model data and predicts result. */
+    private static class PredictLabelFunction extends RichMapFunction<Row, 
Row> {
+        private final String featureCol;
+        private KnnModelData knnModelData;
+        private final int k;
+        private final String broadcastKey;
+
+        public PredictLabelFunction(String broadcastKey, int k, String 
featureCol) {
+            this.k = k;
+            this.broadcastKey = broadcastKey;
+            this.featureCol = featureCol;
+        }
+
+        @Override
+        public Row map(Row row) {
+            if (knnModelData == null) {
+                knnModelData =
+                        (KnnModelData)
+                                
getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
+            }
+            DenseVector vector = (DenseVector) row.getField(featureCol);
+            if (vector == null) {
+                return Row.join(row, new Row(1));
+            }
+            Tuple2<List<Double>, List<Double>> tuple2 = findNeighbor(vector);
+            double percent = 1.0 / tuple2.f0.size();
+            Map<Double, Double> detail = new HashMap<>(0);
+            for (Double obj : tuple2.f0) {
+                detail.merge(obj, percent, Double::sum);
+            }
+            double max = 0.0;

Review comment:
       Could we rename this variable to be more self-explanatory, e.g. 
`maxWeight`?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,223 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.classification.knn.KnnModelData.ModelDataDecoder;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+
+/** A Model which classifies data using the model data computed by {@link 
Knn}. */
+public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> {
+    protected Map<Param<?>, Object> params = new HashMap<>();
+    private Table modelDataTable;
+
+    public KnnModel() {
+        ParamUtils.initializeMapWithDefaultValues(params, this);
+    }
+
+    @Override
+    public KnnModel setModelData(Table... modelData) {
+        this.modelDataTable = modelData[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        DataStream<Row> data = tEnv.toDataStream(inputs[0]);
+        DataStream<KnnModelData> knnModel = 
KnnModelData.getModelDataStream(modelDataTable);
+        final String broadcastModelKey = "broadcastModelKey";
+
+        RowTypeInfo inputTypeInfo = 
TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(), 
BasicTypeInfo.DOUBLE_TYPE_INFO),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), 
getPredictionCol()));
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(data),
+                        Collections.singletonMap(broadcastModelKey, knnModel),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(
+                                    new PredictLabelFunction(
+                                            broadcastModelKey, getK(), 
getFeaturesCol()),
+                                    outputTypeInfo);
+                        });
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return this.params;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                KnnModelData.getModelDataStream(modelDataTable),
+                path,
+                new KnnModelData.ModelDataEncoder(),
+                0);
+    }
+
+    /**
+     * Loads model data from path.
+     *
+     * @param env Stream execution environment.
+     * @param path Model path.
+     * @return Knn model.
+     */
+    public static KnnModel load(StreamExecutionEnvironment env, String path) 
throws IOException {
+        KnnModel model = ReadWriteUtils.loadStageParam(path);
+        Table modelDataTable = ReadWriteUtils.loadModelData(env, path, new 
ModelDataDecoder(), 0);
+        return model.setModelData(modelDataTable);
+    }
+
+    /** This operator loads model data and predicts result. */
+    private static class PredictLabelFunction extends RichMapFunction<Row, 
Row> {
+        private final String featureCol;
+        private KnnModelData knnModelData;
+        private final int k;
+        private final String broadcastKey;
+
+        public PredictLabelFunction(String broadcastKey, int k, String 
featureCol) {
+            this.k = k;
+            this.broadcastKey = broadcastKey;
+            this.featureCol = featureCol;
+        }
+
+        @Override
+        public Row map(Row row) {
+            if (knnModelData == null) {
+                knnModelData =
+                        (KnnModelData)
+                                
getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
+            }
+            DenseVector vector = (DenseVector) row.getField(featureCol);
+            if (vector == null) {
+                return Row.join(row, new Row(1));
+            }
+            Tuple2<List<Double>, List<Double>> tuple2 = findNeighbor(vector);
+            double percent = 1.0 / tuple2.f0.size();
+            Map<Double, Double> detail = new HashMap<>(0);

Review comment:
       nits: Could we make this variable name more self-explanatory, e.g. 
`labelWeights`?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
##########
@@ -0,0 +1,214 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the KNN algorithm.
+ *
+ * <p>See: https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm.
+ */
+public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> {
+
+    protected Map<Param<?>, Object> params = new HashMap<>();
+
+    public Knn() {
+        ParamUtils.initializeMapWithDefaultValues(params, this);
+    }
+
+    @Override
+    public KnnModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        /* Tuple2 : <sampleVector, label> */
+        DataStream<Tuple2<DenseVector, Double>> inputData =
+                tEnv.toDataStream(inputs[0])
+                        .map(
+                                new MapFunction<Row, Tuple2<DenseVector, 
Double>>() {
+                                    @Override
+                                    public Tuple2<DenseVector, Double> map(Row 
value) {
+                                        Double label = (Double) 
value.getField(getLabelCol());
+                                        DenseVector feature =
+                                                (DenseVector) 
value.getField(getFeaturesCol());
+                                        return Tuple2.of(feature, label);
+                                    }
+                                });
+        DataStream<KnnModelData> distributedModelData = 
prepareModelData(inputData);
+        DataStream<KnnModelData> modelData = 
mergeModelData(distributedModelData);
+        KnnModel model = new 
KnnModel().setModelData(tEnv.fromDataStream(modelData));
+        ReadWriteUtils.updateExistingParams(model, getParamMap());
+        return model;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return this.params;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static Knn load(StreamExecutionEnvironment env, String path) throws 
IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    /**
+     * Prepares distributed knn model data. Constructs the sample matrix and 
computes norm of
+     * vectors.
+     *
+     * @param inputData Input vector data with label.
+     * @return Distributed knn model.
+     */
+    private static DataStream<KnnModelData> prepareModelData(
+            DataStream<Tuple2<DenseVector, Double>> inputData) {
+        return DataStreamUtils.mapPartition(
+                inputData,
+                new RichMapPartitionFunction<Tuple2<DenseVector, Double>, 
KnnModelData>() {
+                    @Override
+                    public void mapPartition(
+                            Iterable<Tuple2<DenseVector, Double>> values,
+                            Collector<KnnModelData> out) {
+                        List<Tuple2<DenseVector, Double>> dataPoints = new 
ArrayList<>(0);
+                        for (Tuple2<DenseVector, Double> tuple2 : values) {
+                            dataPoints.add(tuple2);
+                        }
+                        int featureDim = dataPoints.get(0).f0.size();
+                        DenseMatrix packedFeatures = new 
DenseMatrix(featureDim, dataPoints.size());
+                        DenseVector labels = new 
DenseVector(dataPoints.size());
+                        for (int i = 0; i < dataPoints.size(); ++i) {
+                            Tuple2<DenseVector, Double> tuple2 = 
dataPoints.get(i);
+                            labels.values[i] = tuple2.f1;
+                            System.arraycopy(
+                                    tuple2.f0.values,
+                                    0,
+                                    packedFeatures.values,
+                                    i * featureDim,
+                                    featureDim);
+                        }
+                        DenseVector featureNorms = computeNorm(packedFeatures);
+                        if (dataPoints.size() > 0) {
+                            out.collect(new KnnModelData(packedFeatures, 
featureNorms, labels));
+                        }
+                    }
+                },
+                TypeInformation.of(KnnModelData.class));

Review comment:
       It seems that we can simplify the code here by calling 
`mapPartition(...)` without specifying this TypeInformation. Could you double 
check this?
   
   Same for `mergeModelData(...)`.

##########
File path: flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java
##########
@@ -26,9 +26,91 @@
     private static final dev.ludovic.netlib.BLAS JAVA_BLAS =
             dev.ludovic.netlib.JavaBLAS.getInstance();
 
-    /** y += a * x . */
+    /**
+     * \sum_i |x_i| .
+     *
+     * @param x x
+     * @return \sum_i |x_i|
+     */
+    public static double asum(DenseVector x) {
+        return JAVA_BLAS.dasum(x.size(), x.values, 0, 1);
+    }
+
+    /**
+     * y += a * x .
+     *
+     * @param a a
+     * @param x x
+     * @param y y
+     */
     public static void axpy(double a, DenseVector x, DenseVector y) {
         Preconditions.checkArgument(x.size() == y.size(), "Vector size 
mismatched.");
         JAVA_BLAS.daxpy(x.size(), a, x.values, 1, y.values, 1);
     }
+
+    /**
+     * x \cdot y .
+     *
+     * @param x x
+     * @param y y
+     * @return x \cdot y
+     */
+    public static double dot(DenseVector x, DenseVector y) {
+        Preconditions.checkArgument(x.size() == y.size(), "Vector size 
mismatched.");
+        return JAVA_BLAS.ddot(x.size(), x.values, 1, y.values, 1);
+    }
+
+    /**
+     * \sqrt(\sum_i x_i * x_i) .
+     *
+     * @param x x
+     * @return \sqrt(\sum_i x_i * x_i)
+     */
+    public static double norm2(DenseVector x) {
+        return JAVA_BLAS.dnrm2(x.size(), x.values, 1);
+    }
+
+    /**
+     * x = x * a .
+     *
+     * @param a a
+     * @param x x
+     */
+    public static void scal(double a, DenseVector x) {
+        JAVA_BLAS.dscal(x.size(), a, x.values, 1);
+    }
+
+    /**
+     * y = alpha * A * x + beta * y or y = alpha * (A^T) * x + beta * y.
+     *
+     * @param alpha alpha.
+     * @param A m x n matrix A.
+     * @param transA Whether transposes matrix y before multiply.
+     * @param x dense vector with size n.
+     * @param beta beta.
+     * @param y dense vector with size m.
+     */
+    public static void gemv(
+            double alpha,
+            DenseMatrix A,

Review comment:
       It seems that this line causes the style check failure?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to