lindong28 commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r768422566



##########
File path: flink-ml-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java
##########
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.linalg;
+
+import org.junit.Test;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+/** Tests the {@link BLAS}. */
+public class BLASTest {
+
+    private static final double TOLERANCE = 1e-7;
+
+    private DenseVector inputDenseVec = Vectors.dense(1, -2, 3, 4, -5);
+    private DenseMatrix inputDenseMat =

Review comment:
       nit: could we make both variables `final`?

##########
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##########
@@ -0,0 +1,224 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+
+/** Tests {@link Knn} and {@link KnnModel}. */
+public class KnnTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainData;
+    private Table predictData;
+    private static final List<Row> trainRows =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(Vectors.dense(2.0, 3.0), 1.0),
+                            Row.of(Vectors.dense(2.1, 3.1), 1.0),
+                            Row.of(Vectors.dense(200.1, 300.1), 2.0),
+                            Row.of(Vectors.dense(200.2, 300.2), 2.0),
+                            Row.of(Vectors.dense(200.3, 300.3), 2.0),
+                            Row.of(Vectors.dense(200.4, 300.4), 2.0),
+                            Row.of(Vectors.dense(200.4, 300.4), 2.0),
+                            Row.of(Vectors.dense(200.6, 300.6), 2.0),
+                            Row.of(Vectors.dense(2.1, 3.1), 1.0),
+                            Row.of(Vectors.dense(2.1, 3.1), 1.0),
+                            Row.of(Vectors.dense(2.1, 3.1), 1.0),
+                            Row.of(Vectors.dense(2.1, 3.1), 1.0),
+                            Row.of(Vectors.dense(2.3, 3.2), 1.0),
+                            Row.of(Vectors.dense(2.3, 3.2), 1.0),
+                            Row.of(Vectors.dense(2.8, 3.2), 3.0),
+                            Row.of(Vectors.dense(300., 3.2), 4.0),
+                            Row.of(Vectors.dense(2.2, 3.2), 1.0),
+                            Row.of(Vectors.dense(2.4, 3.2), 5.0),
+                            Row.of(Vectors.dense(2.5, 3.2), 5.0),
+                            Row.of(Vectors.dense(2.5, 3.2), 5.0),
+                            Row.of(Vectors.dense(2.1, 3.1), 1.0)));
+    private static final List<Row> predictRows =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(Vectors.dense(4.0, 4.1), 5.0),
+                            Row.of(Vectors.dense(300, 42), 2.0)));
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        
config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, 
true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        Schema schema =
+                Schema.newBuilder()
+                        .column("f0", DataTypes.of(DenseVector.class))
+                        .column("f1", DataTypes.DOUBLE())
+                        .build();
+        DataStream<Row> dataStream = env.fromCollection(trainRows);
+        trainData = tEnv.fromDataStream(dataStream, schema).as("features", 
"label");
+        DataStream<Row> predDataStream = env.fromCollection(predictRows);
+        predictData = tEnv.fromDataStream(predDataStream, 
schema).as("features", "label");
+    }
+
+    private static void verifyPredictionResult(Table output, String labelCol, 
String predictionCol)
+            throws Exception {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
output).getTableEnvironment();
+        DataStream<Tuple2<Double, Double>> stream =
+                tEnv.toDataStream(output)
+                        .map(
+                                new MapFunction<Row, Tuple2<Double, Double>>() 
{
+                                    @Override
+                                    public Tuple2<Double, Double> map(Row row) 
{
+                                        return Tuple2.of(
+                                                (Double) 
row.getField(labelCol),
+                                                (Double) 
row.getField(predictionCol));
+                                    }
+                                });
+        List<Tuple2<Double, Double>> result = 
IteratorUtils.toList(stream.executeAndCollect());
+        for (Tuple2<Double, Double> t2 : result) {
+            Assert.assertEquals(t2.f0, t2.f1);
+        }
+    }
+
+    @Test
+    public void testParam() {
+        Knn knn = new Knn();
+        assertEquals("features", knn.getFeaturesCol());
+        assertEquals("label", knn.getLabelCol());
+        assertEquals(5, (int) knn.getK());
+        assertEquals("prediction", knn.getPredictionCol());
+        knn.setLabelCol("test_label")
+                .setFeaturesCol("test_features")
+                .setK(4)
+                .setPredictionCol("test_prediction");
+        assertEquals("test_features", knn.getFeaturesCol());
+        assertEquals("test_label", knn.getLabelCol());
+        assertEquals(4, (int) knn.getK());
+        assertEquals("test_prediction", knn.getPredictionCol());
+    }
+
+    @Test
+    public void testFeaturePredictionParam() throws Exception {
+        Knn knn =
+                new Knn()
+                        .setLabelCol("test_label")
+                        .setFeaturesCol("test_features")
+                        .setK(4)
+                        .setPredictionCol("test_prediction");
+        KnnModel model = knn.fit(trainData.as("test_features, test_label"));
+        Table output = model.transform(predictData.as("test_features, 
test_label"))[0];
+        assertEquals(
+                Arrays.asList("test_features", "test_label", 
"test_prediction"),
+                output.getResolvedSchema().getColumnNames());
+    }
+
+    @Test
+    public void testFewerDistinctPointsThanCluster() throws Exception {
+        Knn knn = new Knn();
+        KnnModel model = knn.fit(predictData);
+        Table output = model.transform(predictData)[0];
+        verifyPredictionResult(output, knn.getLabelCol(), 
knn.getPredictionCol());
+    }
+
+    @Test
+    public void testFitAndPredict() throws Exception {
+        Knn knn = new Knn();
+        KnnModel knnModel = knn.fit(trainData);
+        Table output = knnModel.transform(predictData)[0];
+        verifyPredictionResult(output, knn.getLabelCol(), 
knn.getPredictionCol());
+    }
+
+    @Test
+    public void testSaveLoadAndPredict() throws Exception {
+        Knn knn = new Knn();
+        Knn loadedKnn =
+                StageTestUtils.saveAndReload(env, knn, 
tempFolder.newFolder().getAbsolutePath());
+        KnnModel knnModel = loadedKnn.fit(trainData);
+        Table output = knnModel.transform(predictData)[0];
+        verifyPredictionResult(output, knn.getLabelCol(), 
knn.getPredictionCol());
+    }
+
+    @Test
+    public void testModelSaveLoadAndPredict() throws Exception {

Review comment:
       For code consistency, could we remove this test and update 
`testSaveLoadAndPredict` to test the save/load of both the estimator and the 
model, similar to `LogisticRegressionTest::testSaveLoadAndPredict`?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+
+/** A Model which classifies data using the model data computed by {@link 
Knn}. */
+public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public KnnModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public KnnModel setModelData(Table... modelData) {
+        modelDataTable = modelData[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        DataStream<Row> data = tEnv.toDataStream(inputs[0]);
+        DataStream<KnnModelData> knnModel = 
KnnModelData.getModelDataStream(modelDataTable);
+        final String broadcastModelKey = "broadcastModelKey";
+        RowTypeInfo inputTypeInfo = 
TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(), 
BasicTypeInfo.DOUBLE_TYPE_INFO),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), 
getPredictionCol()));
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(data),
+                        Collections.singletonMap(broadcastModelKey, knnModel),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(
+                                    new PredictLabelFunction(
+                                            broadcastModelKey, getK(), 
getFeaturesCol()),
+                                    outputTypeInfo);
+                        });
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                KnnModelData.getModelDataStream(modelDataTable),
+                path,
+                new KnnModelData.ModelDataEncoder());
+    }
+
+    /**
+     * Loads model data from path.
+     *
+     * @param env Stream execution environment.
+     * @param path Model path.
+     * @return Knn model.
+     */
+    public static KnnModel load(StreamExecutionEnvironment env, String path) 
throws IOException {
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+        KnnModel model = ReadWriteUtils.loadStageParam(path);
+        DataStream<KnnModelData> modelData =
+                ReadWriteUtils.loadModelData(env, path, new 
KnnModelData.ModelDataDecoder());
+        return model.setModelData(tEnv.fromDataStream(modelData));
+    }
+
+    /** This operator loads model data and predicts result. */
+    private static class PredictLabelFunction extends RichMapFunction<Row, 
Row> {
+        private final String featureCol;
+        private KnnModelData knnModelData;
+        private final int k;
+        private final String broadcastKey;
+
+        public PredictLabelFunction(String broadcastKey, int k, String 
featureCol) {
+            this.k = k;
+            this.broadcastKey = broadcastKey;
+            this.featureCol = featureCol;
+        }
+
+        @Override
+        public Row map(Row row) {
+            if (knnModelData == null) {
+                knnModelData =
+                        (KnnModelData)
+                                
getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
+            }
+            DenseVector feature = (DenseVector) row.getField(featureCol);
+            double prediction = predict(feature);
+            return Row.join(row, Row.of(prediction));
+        }
+
+        private double predict(DenseVector inputFeature) {
+            PriorityQueue<Tuple2<Double, Double>> priorityQueue =
+                    new PriorityQueue<>(Comparator.comparingDouble(o -> 
-o.f0));
+            double featureNorm = 0.0;
+            for (int i = 0; i < inputFeature.size(); ++i) {
+                featureNorm += inputFeature.values[i] * inputFeature.values[i];
+            }
+            Tuple2<DenseVector, Double> sample = Tuple2.of(inputFeature, 
featureNorm);

Review comment:
       nit: how about we rename this to `featureAndNorm`?
   
   This seems to be the general convention for naming a Tuple2 variable, and it would also be consistent with the `distanceAndLabels` variable name used below.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+
+/** A Model which classifies data using the model data computed by {@link 
Knn}. */
+public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public KnnModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public KnnModel setModelData(Table... modelData) {
+        modelDataTable = modelData[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        DataStream<Row> data = tEnv.toDataStream(inputs[0]);
+        DataStream<KnnModelData> knnModel = 
KnnModelData.getModelDataStream(modelDataTable);
+        final String broadcastModelKey = "broadcastModelKey";
+        RowTypeInfo inputTypeInfo = 
TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(), 
BasicTypeInfo.DOUBLE_TYPE_INFO),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), 
getPredictionCol()));
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(data),
+                        Collections.singletonMap(broadcastModelKey, knnModel),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(
+                                    new PredictLabelFunction(
+                                            broadcastModelKey, getK(), 
getFeaturesCol()),
+                                    outputTypeInfo);
+                        });
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                KnnModelData.getModelDataStream(modelDataTable),
+                path,
+                new KnnModelData.ModelDataEncoder());
+    }
+
+    /**
+     * Loads model data from path.
+     *
+     * @param env Stream execution environment.
+     * @param path Model path.
+     * @return Knn model.
+     */
+    public static KnnModel load(StreamExecutionEnvironment env, String path) 
throws IOException {
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+        KnnModel model = ReadWriteUtils.loadStageParam(path);
+        DataStream<KnnModelData> modelData =
+                ReadWriteUtils.loadModelData(env, path, new 
KnnModelData.ModelDataDecoder());
+        return model.setModelData(tEnv.fromDataStream(modelData));
+    }
+
+    /** This operator loads model data and predicts result. */
+    private static class PredictLabelFunction extends RichMapFunction<Row, 
Row> {
+        private final String featureCol;
+        private KnnModelData knnModelData;
+        private final int k;
+        private final String broadcastKey;
+
+        public PredictLabelFunction(String broadcastKey, int k, String 
featureCol) {
+            this.k = k;
+            this.broadcastKey = broadcastKey;
+            this.featureCol = featureCol;
+        }
+
+        @Override
+        public Row map(Row row) {
+            if (knnModelData == null) {
+                knnModelData =
+                        (KnnModelData)
+                                
getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
+            }
+            DenseVector feature = (DenseVector) row.getField(featureCol);
+            double prediction = predict(feature);
+            return Row.join(row, Row.of(prediction));
+        }
+
+        private double predict(DenseVector inputFeature) {
+            PriorityQueue<Tuple2<Double, Double>> priorityQueue =
+                    new PriorityQueue<>(Comparator.comparingDouble(o -> 
-o.f0));
+            double featureNorm = 0.0;
+            for (int i = 0; i < inputFeature.size(); ++i) {
+                featureNorm += inputFeature.values[i] * inputFeature.values[i];
+            }
+            Tuple2<DenseVector, Double> sample = Tuple2.of(inputFeature, 
featureNorm);
+            double[] labelValues = knnModelData.labels.values;
+            double[] normValues = knnModelData.featureNorms.values;
+            DenseVector distanceVector = new DenseVector(labelValues.length);
+            BLAS.gemv(-2.0, knnModelData.packedFeatures, true, sample.f0, 0.0, 
distanceVector);
+            for (int i = 0; i < distanceVector.values.length; i++) {
+                distanceVector.values[i] =
+                        Math.sqrt(Math.abs(distanceVector.values[i] + 
sample.f1 + normValues[i]));
+            }
+            List<Tuple2<Double, Double>> distanceAndLabels = new 
ArrayList<>(labelValues.length);
+            for (int i = 0; i < labelValues.length; i++) {
+                distanceAndLabels.add(Tuple2.of(distanceVector.values[i], 
labelValues[i]));
+            }
+            for (Tuple2<Double, Double> distanceAndLabel : distanceAndLabels) {
+                if (priorityQueue.size() < k) {
+                    priorityQueue.add(Tuple2.of(distanceAndLabel.f0, 
distanceAndLabel.f1));
+                } else {
+                    Tuple2<Double, Double> head = priorityQueue.peek();
+                    if (priorityQueue.comparator().compare(head, 
distanceAndLabel) < 0) {
+                        Tuple2<Double, Double> peek = priorityQueue.poll();
+                        if (peek != null) {

Review comment:
       nit: It seems that `peek != null` is guaranteed to be true: the else branch is only reached once the queue already holds `k` elements (assuming `k >= 1`), and `head` has just been dereferenced by the comparator, so `poll()` always returns that non-null head. Could we remove this check?
   
   Instead of inserting `peek` after mutating it, would it be simpler to insert `distanceAndLabel` directly? For example:
   
   
   ```
   if (priorityQueue.size() < k) {
       priorityQueue.add(Tuple2.of(distanceAndLabel.f0, distanceAndLabel.f1));
   } else {
       Tuple2<Double, Double> head = priorityQueue.peek();
       if (head.f0 > distanceAndLabel.f0) {
           priorityQueue.poll();
           priorityQueue.add(distanceAndLabel);
       }
   }
   ```

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+
+/** A Model which classifies data using the model data computed by {@link 
Knn}. */
+public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public KnnModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public KnnModel setModelData(Table... modelData) {
+        modelDataTable = modelData[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        DataStream<Row> data = tEnv.toDataStream(inputs[0]);
+        DataStream<KnnModelData> knnModel = 
KnnModelData.getModelDataStream(modelDataTable);
+        final String broadcastModelKey = "broadcastModelKey";
+        RowTypeInfo inputTypeInfo = 
TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(), 
BasicTypeInfo.DOUBLE_TYPE_INFO),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), 
getPredictionCol()));
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(data),
+                        Collections.singletonMap(broadcastModelKey, knnModel),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(
+                                    new PredictLabelFunction(
+                                            broadcastModelKey, getK(), 
getFeaturesCol()),
+                                    outputTypeInfo);
+                        });
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                KnnModelData.getModelDataStream(modelDataTable),
+                path,
+                new KnnModelData.ModelDataEncoder());
+    }
+
+    /**
+     * Loads model data from path.
+     *
+     * @param env Stream execution environment.
+     * @param path Model path.
+     * @return Knn model.
+     */
+    public static KnnModel load(StreamExecutionEnvironment env, String path) 
throws IOException {
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+        KnnModel model = ReadWriteUtils.loadStageParam(path);
+        DataStream<KnnModelData> modelData =
+                ReadWriteUtils.loadModelData(env, path, new 
KnnModelData.ModelDataDecoder());
+        return model.setModelData(tEnv.fromDataStream(modelData));
+    }
+
+    /** This operator loads model data and predicts result. */
+    private static class PredictLabelFunction extends RichMapFunction<Row, 
Row> {
+        private final String featureCol;
+        private KnnModelData knnModelData;
+        private final int k;
+        private final String broadcastKey;
+
+        public PredictLabelFunction(String broadcastKey, int k, String 
featureCol) {
+            this.k = k;
+            this.broadcastKey = broadcastKey;
+            this.featureCol = featureCol;
+        }
+
+        @Override
+        public Row map(Row row) {
+            if (knnModelData == null) {
+                knnModelData =
+                        (KnnModelData)
+                                
getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
+            }
+            DenseVector feature = (DenseVector) row.getField(featureCol);
+            double prediction = predict(feature);
+            return Row.join(row, Row.of(prediction));
+        }
+
+        private double predict(DenseVector inputFeature) {
+            PriorityQueue<Tuple2<Double, Double>> priorityQueue =
+                    new PriorityQueue<>(Comparator.comparingDouble(o -> 
-o.f0));
+            double featureNorm = 0.0;
+            for (int i = 0; i < inputFeature.size(); ++i) {
+                featureNorm += inputFeature.values[i] * inputFeature.values[i];
+            }
+            Tuple2<DenseVector, Double> sample = Tuple2.of(inputFeature, 
featureNorm);
+            double[] labelValues = knnModelData.labels.values;
+            double[] normValues = knnModelData.featureNorms.values;
+            DenseVector distanceVector = new DenseVector(labelValues.length);
+            BLAS.gemv(-2.0, knnModelData.packedFeatures, true, sample.f0, 0.0, 
distanceVector);
+            for (int i = 0; i < distanceVector.values.length; i++) {
+                distanceVector.values[i] =
+                        Math.sqrt(Math.abs(distanceVector.values[i] + 
sample.f1 + normValues[i]));
+            }
+            List<Tuple2<Double, Double>> distanceAndLabels = new 
ArrayList<>(labelValues.length);
+            for (int i = 0; i < labelValues.length; i++) {
+                distanceAndLabels.add(Tuple2.of(distanceVector.values[i], 
labelValues[i]));
+            }
+            for (Tuple2<Double, Double> distanceAndLabel : distanceAndLabels) {
+                if (priorityQueue.size() < k) {
+                    priorityQueue.add(Tuple2.of(distanceAndLabel.f0, 
distanceAndLabel.f1));

Review comment:
       Instead of instantiating a new Tuple2, could we insert 
`distanceAndLabel` directly?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+
+/** A Model which classifies data using the model data computed by {@link 
Knn}. */
+public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public KnnModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public KnnModel setModelData(Table... modelData) {

Review comment:
       nits: for better code consistency with other algorithms, could we change 
the input parameter name to `inputs`?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+
+/** A Model which classifies data using the model data computed by {@link 
Knn}. */
+public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public KnnModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public KnnModel setModelData(Table... modelData) {
+        modelDataTable = modelData[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        DataStream<Row> data = tEnv.toDataStream(inputs[0]);
+        DataStream<KnnModelData> knnModel = 
KnnModelData.getModelDataStream(modelDataTable);
+        final String broadcastModelKey = "broadcastModelKey";
+        RowTypeInfo inputTypeInfo = 
TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(), 
BasicTypeInfo.DOUBLE_TYPE_INFO),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), 
getPredictionCol()));
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(data),
+                        Collections.singletonMap(broadcastModelKey, knnModel),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(
+                                    new PredictLabelFunction(
+                                            broadcastModelKey, getK(), 
getFeaturesCol()),
+                                    outputTypeInfo);
+                        });
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                KnnModelData.getModelDataStream(modelDataTable),
+                path,
+                new KnnModelData.ModelDataEncoder());
+    }
+
+    /**
+     * Loads model data from path.
+     *
+     * @param env Stream execution environment.
+     * @param path Model path.
+     * @return Knn model.
+     */
+    public static KnnModel load(StreamExecutionEnvironment env, String path) 
throws IOException {
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+        KnnModel model = ReadWriteUtils.loadStageParam(path);
+        DataStream<KnnModelData> modelData =
+                ReadWriteUtils.loadModelData(env, path, new 
KnnModelData.ModelDataDecoder());
+        return model.setModelData(tEnv.fromDataStream(modelData));
+    }
+
+    /** This operator loads model data and predicts result. */
+    private static class PredictLabelFunction extends RichMapFunction<Row, 
Row> {
+        private final String featureCol;
+        private KnnModelData knnModelData;
+        private final int k;
+        private final String broadcastKey;
+
+        public PredictLabelFunction(String broadcastKey, int k, String 
featureCol) {
+            this.k = k;
+            this.broadcastKey = broadcastKey;
+            this.featureCol = featureCol;
+        }
+
+        @Override
+        public Row map(Row row) {
+            if (knnModelData == null) {
+                knnModelData =
+                        (KnnModelData)
+                                
getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
+            }
+            DenseVector feature = (DenseVector) row.getField(featureCol);
+            double prediction = predict(feature);
+            return Row.join(row, Row.of(prediction));
+        }
+
+        private double predict(DenseVector inputFeature) {
+            PriorityQueue<Tuple2<Double, Double>> priorityQueue =
+                    new PriorityQueue<>(Comparator.comparingDouble(o -> 
-o.f0));
+            double featureNorm = 0.0;
+            for (int i = 0; i < inputFeature.size(); ++i) {
+                featureNorm += inputFeature.values[i] * inputFeature.values[i];
+            }
+            Tuple2<DenseVector, Double> sample = Tuple2.of(inputFeature, 
featureNorm);
+            double[] labelValues = knnModelData.labels.values;
+            double[] normValues = knnModelData.featureNorms.values;
+            DenseVector distanceVector = new DenseVector(labelValues.length);
+            BLAS.gemv(-2.0, knnModelData.packedFeatures, true, sample.f0, 0.0, 
distanceVector);
+            for (int i = 0; i < distanceVector.values.length; i++) {
+                distanceVector.values[i] =
+                        Math.sqrt(Math.abs(distanceVector.values[i] + 
sample.f1 + normValues[i]));
+            }
+            List<Tuple2<Double, Double>> distanceAndLabels = new 
ArrayList<>(labelValues.length);
+            for (int i = 0; i < labelValues.length; i++) {
+                distanceAndLabels.add(Tuple2.of(distanceVector.values[i], 
labelValues[i]));
+            }
+            for (Tuple2<Double, Double> distanceAndLabel : distanceAndLabels) {
+                if (priorityQueue.size() < k) {
+                    priorityQueue.add(Tuple2.of(distanceAndLabel.f0, 
distanceAndLabel.f1));
+                } else {
+                    Tuple2<Double, Double> head = priorityQueue.peek();
+                    if (priorityQueue.comparator().compare(head, 
distanceAndLabel) < 0) {
+                        Tuple2<Double, Double> peek = priorityQueue.poll();
+                        if (peek != null) {
+                            peek.f0 = distanceAndLabel.f0;
+                            peek.f1 = distanceAndLabel.f1;
+                            priorityQueue.add(peek);
+                        }
+                    }
+                }
+            }
+            List<Double> labels = new ArrayList<>();
+            while (!priorityQueue.isEmpty()) {
+                Tuple2<Double, Double> distanceAndLabel = priorityQueue.poll();
+                labels.add(distanceAndLabel.f1);
+            }
+            double percent = 1.0 / labels.size();
+            Map<Double, Double> labelWeights = new HashMap<>(0);
+            for (Double label : labels) {
+                labelWeights.merge(label, percent, Double::sum);
+            }
+            double maxWeight = 0.0;
+            double prediction = 0.0;

Review comment:
       It seems that `label` or `predictedLabel` would be a slightly more 
self-explanatory name than `prediction`. And maybe rename this method from `predict` 
to `predictLabel` to be more consistent with `PredictLabelFunction`. What do you think?

##########
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/DenseMatrix.java
##########
@@ -0,0 +1,69 @@
+package org.apache.flink.ml.linalg;
+
+import org.apache.flink.api.common.typeinfo.TypeInfo;
+import org.apache.flink.ml.linalg.typeinfo.DenseMatrixTypeInfoFactory;
+import org.apache.flink.util.Preconditions;
+
+/**
+ * Column-major dense matrix. The entry values are stored in a single array of 
doubles with columns
+ * listed in sequence.
+ */
+@TypeInfo(DenseMatrixTypeInfoFactory.class)
+public class DenseMatrix implements Matrix {
+
+    /** Row dimension. */
+    private final int numRows;
+
+    /** Column dimension. */
+    private final int numCols;
+
+    /**
+     * Array for internal storage of elements.
+     *
+     * <p>The matrix data is stored in column major format internally.
+     */
+    public final double[] values;
+
+    /**
+     * Constructs an m-by-n matrix of zeros.
+     *
+     * @param numRows Number of rows.
+     * @param numCols Number of columns.
+     */
+    public DenseMatrix(int numRows, int numCols) {
+        this(numRows, numCols, new double[numRows * numCols]);
+    }
+
+    /**
+     * Constructs a matrix from a 1-D array. The data in the array should be 
organized in column
+     * major.
+     *
+     * @param numRows Number of rows.
+     * @param numCols Number of cols.
+     * @param values One-dimensional array of doubles.
+     */
+    public DenseMatrix(int numRows, int numCols, double[] values) {
+        Preconditions.checkArgument(values.length == numRows * numCols);
+        this.numRows = numRows;
+        this.numCols = numCols;
+        this.values = values;
+    }
+
+    @Override
+    public int numRows() {
+        return numRows;
+    }
+
+    @Override
+    public int numCols() {
+        return numCols;
+    }
+
+    @Override
+    public double get(int i, int j) {
+        if (i >= numRows || j >= numCols) {

Review comment:
       nits: could we also verify that `i >= 0 && j >= 0`?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
##########
@@ -0,0 +1,155 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the KNN algorithm.
+ *
+ * <p>See: https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm.
+ */
+public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> {
+
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public Knn() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public KnnModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        /* Tuple3 : <feature, label, norm> */
+        DataStream<Tuple3<DenseVector, Double, Double>> inputDataWithNorm =
+                computeNorm(tEnv.toDataStream(inputs[0]));
+        DataStream<KnnModelData> modelData = genModelData(inputDataWithNorm);
+        KnnModel model = new 
KnnModel().setModelData(tEnv.fromDataStream(modelData));
+        ReadWriteUtils.updateExistingParams(model, getParamMap());
+        return model;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static Knn load(StreamExecutionEnvironment env, String path) throws 
IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    /**
+     * Generates knn model data.
+     *
+     * @param inputDataWithNorm Input data with feature norm.
+     * @return Knn model.
+     */
+    private static DataStream<KnnModelData> genModelData(
+            DataStream<Tuple3<DenseVector, Double, Double>> inputDataWithNorm) 
{
+        DataStream<KnnModelData> modelData =
+                DataStreamUtils.mapPartition(
+                        inputDataWithNorm,
+                        new RichMapPartitionFunction<
+                                Tuple3<DenseVector, Double, Double>, 
KnnModelData>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<Tuple3<DenseVector, Double, 
Double>> values,
+                                    Collector<KnnModelData> out) {
+                                List<Tuple3<DenseVector, Double, Double>> 
buffer =
+                                        new ArrayList<>(1);

Review comment:
       hmm... it is not clear why we explicitly set the initial size to 1. 
Actually, the initial size is likely to be much larger than 1 in most cases. 
Could we simplify the code to just use `new ArrayList<>()`?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+
+/** A Model which classifies data using the model data computed by {@link 
Knn}. */
+public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public KnnModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public KnnModel setModelData(Table... modelData) {
+        modelDataTable = modelData[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        DataStream<Row> data = tEnv.toDataStream(inputs[0]);
+        DataStream<KnnModelData> knnModel = 
KnnModelData.getModelDataStream(modelDataTable);
+        final String broadcastModelKey = "broadcastModelKey";
+        RowTypeInfo inputTypeInfo = 
TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(), 
BasicTypeInfo.DOUBLE_TYPE_INFO),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), 
getPredictionCol()));
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(data),
+                        Collections.singletonMap(broadcastModelKey, knnModel),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(
+                                    new PredictLabelFunction(
+                                            broadcastModelKey, getK(), 
getFeaturesCol()),
+                                    outputTypeInfo);
+                        });
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                KnnModelData.getModelDataStream(modelDataTable),
+                path,
+                new KnnModelData.ModelDataEncoder());
+    }
+
+    /**
+     * Loads model data from path.
+     *
+     * @param env Stream execution environment.
+     * @param path Model path.
+     * @return Knn model.
+     */
+    public static KnnModel load(StreamExecutionEnvironment env, String path) 
throws IOException {
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+        KnnModel model = ReadWriteUtils.loadStageParam(path);
+        DataStream<KnnModelData> modelData =
+                ReadWriteUtils.loadModelData(env, path, new 
KnnModelData.ModelDataDecoder());
+        return model.setModelData(tEnv.fromDataStream(modelData));
+    }
+
+    /** This operator loads model data and predicts result. */
+    private static class PredictLabelFunction extends RichMapFunction<Row, 
Row> {
+        private final String featureCol;
+        private KnnModelData knnModelData;
+        private final int k;
+        private final String broadcastKey;
+
+        public PredictLabelFunction(String broadcastKey, int k, String 
featureCol) {
+            this.k = k;
+            this.broadcastKey = broadcastKey;
+            this.featureCol = featureCol;
+        }
+
+        @Override
+        public Row map(Row row) {
+            if (knnModelData == null) {
+                knnModelData =
+                        (KnnModelData)
+                                
getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
+            }
+            DenseVector feature = (DenseVector) row.getField(featureCol);
+            double prediction = predict(feature);
+            return Row.join(row, Row.of(prediction));
+        }
+
+        private double predict(DenseVector inputFeature) {

Review comment:
       nits: could we simplify the param name to `feature`? This would also be 
more consistent with `LogisticRegressionModel::predictRaw(..)` and with 
the name of the variable used by the caller.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to