zhipeng93 commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r763924687



##########
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseMatrixSerializer.java
##########
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.flink.ml.linalg.typeinfo;
+
+import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.ml.linalg.DenseMatrix;
+
+import java.io.IOException;
+import java.util.Arrays;
+
+/** Specialized serializer for {@code DenseMatrix}. */

Review comment:
       Can you change `{@code DenseMatrix}` to `{@link DenseMatrix}` in this Javadoc?

##########
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseMatrixSerializer.java
##########
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.flink.ml.linalg.typeinfo;
+
+import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.ml.linalg.DenseMatrix;
+
+import java.io.IOException;
+import java.util.Arrays;
+
+/** Specialized serializer for {@code DenseMatrix}. */
+public final class DenseMatrixSerializer extends 
TypeSerializerSingleton<DenseMatrix> {
+
+    private static final long serialVersionUID = 1L;
+
+    private static final double[] EMPTY = new double[0];
+
+    private static final DenseMatrixSerializer INSTANCE = new 
DenseMatrixSerializer();
+
+    @Override
+    public boolean isImmutableType() {
+        return false;
+    }
+
+    @Override
+    public DenseMatrix createInstance() {
+        return new DenseMatrix(0, 0, EMPTY);
+    }
+
+    @Override
+    public DenseMatrix copy(DenseMatrix from) {
+        return new DenseMatrix(
+                from.numRows, from.numCols, Arrays.copyOf(from.values, 
from.values.length));
+    }
+
+    @Override
+    public DenseMatrix copy(DenseMatrix from, DenseMatrix reuse) {
+        if (from.values.length == reuse.values.length) {
+            System.arraycopy(from.values, 0, reuse.values, 0, 
from.values.length);
+            return reuse;
+        }
+        return copy(from);
+    }
+
+    @Override
+    public int getLength() {
+        return -1;
+    }
+
+    @Override
+    public void serialize(DenseMatrix matrix, DataOutputView target) throws 
IOException {
+        if (matrix == null) {
+            throw new IllegalArgumentException("The matrix must not be null.");
+        }
+
+        final int len = matrix.values.length;
+        target.writeInt(matrix.numRows);
+        target.writeInt(matrix.numCols);
+        for (int i = 0; i < len; i++) {
+            target.writeDouble(matrix.values[i]);
+        }
+    }
+
+    @Override
+    public DenseMatrix deserialize(DataInputView source) throws IOException {
+        int m = source.readInt();
+        int n = source.readInt();
+        double[] values = new double[m * n];
+        for (int i = 0; i < m * n; i++) {
+            values[i] = source.readDouble();
+        }
+        return new DenseMatrix(m, n, values);
+    }
+
+    private static void readDoubleArray(double[] dst, DataInputView source, 
int len)

Review comment:
       Can we make the coding style consistent?
   For example, can we also use `readDoubleArray` in `public DenseMatrix deserialize(DataInputView source)`, or just remove this method?
   
   

##########
File path: flink-ml-lib/pom.xml
##########
@@ -106,6 +106,11 @@ under the License.
       <type>jar</type>
       <scope>test</scope>
     </dependency>
+    <dependency>

Review comment:
       I think this dependency is unnecessary here, given that we have moved all 
infrastructure to `flink-ml-core`.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java
##########
@@ -0,0 +1,148 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.types.Row;
+
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.List;
+
+/** Knn model data, which stores the data used to calculate the distances 
between nodes. */
+public class KnnModelData implements Serializable, Cloneable {
+    private final List<Tuple3<DenseMatrix, DenseVector, Integer[]>> dictData;
+    private final Comparator<? super Tuple2<Double, Integer>> comparator;
+
+    /**
+     * Constructor.
+     *
+     * @param list Row list.
+     */
+    public KnnModelData(List<Row> list) {
+        this.dictData = new ArrayList<>(list.size());
+        for (Row row : list) {
+            this.dictData.add(
+                    Tuple3.of(
+                            (DenseMatrix) row.getField(0),
+                            (DenseVector) row.getField(1),
+                            (Integer[]) row.getField(2)));
+        }
+        comparator = Comparator.comparingDouble(o -> -o.f0);
+    }
+
+    /**
+     * Gets comparator.
+     *
+     * @return Comparator.
+     */
+    public Comparator<? super Tuple2<Double, Integer>> getQueueComparator() {
+        return comparator;
+    }
+
+    /**
+     * Gets dictionary data size.
+     *
+     * @return Dictionary data size.
+     */
+    public Integer getLength() {
+        return dictData.size();
+    }
+
+    public List<Tuple3<DenseMatrix, DenseVector, Integer[]>> getDictData() {
+        return dictData;
+    }
+
+    /** Encoder for the Knn model data. */
+    public static class ModelDataEncoder implements Encoder<Row> {
+        @Override
+        public void encode(Row modelData, OutputStream outputStream) {
+            Kryo kryo = new Kryo();

Review comment:
       Can you reuse `DenseVectorSerializer` (and the new `DenseMatrixSerializer`) here instead of Kryo, similar to what LR does?

##########
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##########
@@ -0,0 +1,277 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
+
+import static org.junit.Assert.assertEquals;
+
+/** Knn algorithm test. */
+public class KnnTest {
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainData;
+    private static final String LABEL_COL = "label";
+    private static final String PRED_COL = "pred";
+    private static final String VEC_COL = "vec";
+    List<Row> trainArray =

Review comment:
       nit: `trainArray` could be private here (the same applies to `testArray`).

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,397 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.functions.AbstractRichFunction;
+import org.apache.flink.api.connector.source.Source;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.connector.file.sink.FileSink;
+import org.apache.flink.connector.file.src.FileSource;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.types.Row;
+
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
+
+import dev.ludovic.netlib.blas.F2jBLAS;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.PriorityQueue;
+import java.util.function.Function;
+
+/** Knn model fitted by estimator. */
+public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> {
+    protected Map<Param<?>, Object> params = new HashMap<>();
+    private Table[] modelData;
+
+    /** Constructor. */
+    public KnnModel() {
+        ParamUtils.initializeMapWithDefaultValues(params, this);
+    }
+
+    /**
+     * Sets model data for knn prediction.
+     *
+     * @param modelData Knn model data.
+     * @return Knn model.
+     */
+    @Override
+    public KnnModel setModelData(Table... modelData) {
+        this.modelData = modelData;
+        return this;
+    }
+
+    /**
+     * Gets model data.
+     *
+     * @return Table array.
+     */
+    @Override
+    public Table[] getModelData() {
+        return modelData;
+    }
+
+    /**
+     * Predicts label with knn model.
+     *
+     * @param inputs List of tables.
+     * @return Prediction result.
+     */
+    @Override
+    public Table[] transform(Table... inputs) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        DataStream<Row> input = tEnv.toDataStream(inputs[0]);
+        DataStream<Row> model = tEnv.toDataStream(modelData[0]);
+        final String broadcastKey = "broadcastModelKey";
+        Map<String, DataStream<?>> broadcastMap = new HashMap<>(1);
+        broadcastMap.put(broadcastKey, model);
+        String resultCols = getPredictionCol();
+        DataType resultTypes = DataTypes.INT();
+        ResolvedSchema outputSchema =
+                TableUtils.getOutputSchema(inputs[0].getResolvedSchema(), 
resultCols, resultTypes);
+
+        Function<List<DataStream<?>>, DataStream<Row>> function =
+                dataStreams -> {
+                    DataStream stream = dataStreams.get(0);
+                    return stream.transform(
+                            "mapFunc",
+                            TableUtils.getRowTypeInfo(outputSchema),
+                            new PredictOperator(
+                                    inputs[0].getResolvedSchema(),
+                                    broadcastKey,
+                                    getK(),
+                                    getFeaturesCol()));
+                };
+
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(input), broadcastMap, 
function);
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    /** This operator loads model data and does the prediction. */
+    private static class PredictOperator
+            extends AbstractUdfStreamOperator<Row, AbstractRichFunction>
+            implements OneInputStreamOperator<Row, Row> {
+
+        private boolean firstEle = true;
+        private final String[] reservedCols;
+        private final String featureCol;
+        private transient KnnModelData modelData;
+        private final Integer topN;
+        private final String broadcastKey;
+
+        public PredictOperator(
+                ResolvedSchema dataSchema, String broadcastKey, int k, String 
featureCol) {
+            super(new AbstractRichFunction() {});
+            reservedCols = dataSchema.getColumnNames().toArray(new String[0]);
+            this.topN = k;
+            this.broadcastKey = broadcastKey;
+            this.featureCol = featureCol;
+        }
+
+        @Override
+        public void processElement(StreamRecord<Row> streamRecord) throws 
Exception {
+            Row value = streamRecord.getValue();
+            output.collect(new StreamRecord<>(map(value)));
+        }
+
+        public Row map(Row row) throws Exception {

Review comment:
       `map` is only called from `processElement`; can it be `private` instead of `public`?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java
##########
@@ -0,0 +1,148 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.types.Row;
+
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.List;
+
+/** Knn model data, which stores the data used to calculate the distances 
between nodes. */
+public class KnnModelData implements Serializable, Cloneable {

Review comment:
       Can you remove `implements Serializable, Cloneable` here?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
##########
@@ -0,0 +1,217 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.MapPartitionFunctionWrapper;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the KNN algorithm. KNN is to classify 
unlabeled observations by
+ * assigning them to the class of the most similar labeled examples.
+ *
+ * <p>See: https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm.
+ */
+public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> {
+
+    protected Map<Param<?>, Object> params = new HashMap<>();
+
+    /** Constructor. */
+    public Knn() {
+        ParamUtils.initializeMapWithDefaultValues(params, this);
+    }
+
+    /**
+     * Fits data and produces knn model.
+     *
+     * @param inputs A list of tables, including train data table.
+     * @return Knn model.
+     */
+    @Override
+    public KnnModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        DataStream<Row> input = tEnv.toDataStream(inputs[0]);
+        String labelCol = getLabelCol();
+
+        DataStream<Tuple2<DenseVector, Integer>> trainData =
+                input.map(
+                        new MapFunction<Row, Tuple2<DenseVector, Integer>>() {
+                            @Override
+                            public Tuple2<DenseVector, Integer> map(Row value) 
{
+                                Integer label = (Integer) 
value.getField(labelCol);
+                                DenseVector vec = (DenseVector) 
value.getField(getFeaturesCol());
+                                return Tuple2.of(vec, label);
+                            }
+                        });
+
+        DataStream<Row> model = buildModel(trainData);
+        KnnModel knnModel =

Review comment:
       Can you use `ReadWriteUtils.updateExistingParams(model, paramMap)` here?

##########
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##########
@@ -0,0 +1,277 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
+
+import static org.junit.Assert.assertEquals;
+
+/** Knn algorithm test. */
+public class KnnTest {
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainData;
+    private static final String LABEL_COL = "label";
+    private static final String PRED_COL = "pred";
+    private static final String VEC_COL = "vec";
+    List<Row> trainArray =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(1, Vectors.dense(2.0, 3.0)),
+                            Row.of(1, Vectors.dense(2.1, 3.1)),
+                            Row.of(2, Vectors.dense(200.1, 300.1)),
+                            Row.of(2, Vectors.dense(200.2, 300.2)),
+                            Row.of(2, Vectors.dense(200.3, 300.3)),
+                            Row.of(2, Vectors.dense(200.4, 300.4)),
+                            Row.of(2, Vectors.dense(200.4, 300.4)),
+                            Row.of(2, Vectors.dense(200.6, 300.6)),
+                            Row.of(1, Vectors.dense(2.1, 3.1)),
+                            Row.of(1, Vectors.dense(2.1, 3.1)),
+                            Row.of(1, Vectors.dense(2.1, 3.1)),
+                            Row.of(1, Vectors.dense(2.1, 3.1)),
+                            Row.of(1, Vectors.dense(2.3, 3.2)),
+                            Row.of(1, Vectors.dense(2.3, 3.2)),
+                            Row.of(3, Vectors.dense(2.8, 3.2)),
+                            Row.of(4, Vectors.dense(300., 3.2)),
+                            Row.of(1, Vectors.dense(2.2, 3.2)),
+                            Row.of(5, Vectors.dense(2.4, 3.2)),
+                            Row.of(5, Vectors.dense(2.5, 3.2)),
+                            Row.of(5, Vectors.dense(2.5, 3.2)),
+                            Row.of(1, Vectors.dense(2.1, 3.1))));
+
+    List<Row> testArray =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of(5, Vectors.dense(4.0, 4.1)), Row.of(2, 
Vectors.dense(300, 42))));
+    private Table testData;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        
config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, 
true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        Schema schema =
+                Schema.newBuilder()
+                        .column("f0", DataTypes.INT())
+                        .column("f1", DataTypes.of(DenseVector.class))
+                        .build();
+
+        DataStream<Row> dataStream = env.fromCollection(trainArray);
+        trainData = tEnv.fromDataStream(dataStream, schema).as(LABEL_COL + "," 
+ VEC_COL);
+
+        DataStream<Row> predDataStream = env.fromCollection(testArray);
+        testData = tEnv.fromDataStream(predDataStream, schema).as(LABEL_COL + 
"," + VEC_COL);
+    }
+
+    // Executes the graph and returns a list which has true label and predict 
label.
+    private static List<Tuple2<String, String>> executeAndCollect(Table 
output) throws Exception {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
output).getTableEnvironment();
+
+        DataStream<Tuple2<Integer, Integer>> stream =
+                tEnv.toDataStream(output)
+                        .map(
+                                new MapFunction<Row, Tuple2<Integer, 
Integer>>() {
+                                    @Override
+                                    public Tuple2<Integer, Integer> map(Row 
row) {
+                                        return Tuple2.of(
+                                                (Integer) 
row.getField(LABEL_COL),
+                                                (Integer) 
row.getField(PRED_COL));
+                                    }
+                                });
+        return IteratorUtils.toList(stream.executeAndCollect());
+    }
+
+    private static void verifyClusteringResult(List<Tuple2<String, String>> 
result) {
+        for (Tuple2<String, String> t2 : result) {
+            Assert.assertEquals(t2.f0, t2.f1);
+        }
+    }
+
+    /** Tests Param. */
+    @Test
+    public void testParam() {
+        Knn knnOrigin = new Knn();
+        assertEquals(LABEL_COL, knnOrigin.getLabelCol());
+        assertEquals(10L, knnOrigin.getK().longValue());
+        assertEquals("prediction", knnOrigin.getPredictionCol());
+
+        Knn knn =
+                new Knn()
+                        .setLabelCol(LABEL_COL)

Review comment:
       Can we set params that differ from the default values? Otherwise the test cannot tell whether the setters actually take effect.

##########
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseMatrixSerializer.java
##########
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.flink.ml.linalg.typeinfo;
+
+import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.ml.linalg.DenseMatrix;
+
+import java.io.IOException;
+import java.util.Arrays;
+
+/** Specialized serializer for {@code DenseMatrix}. */
+public final class DenseMatrixSerializer extends 
TypeSerializerSingleton<DenseMatrix> {
+
+    private static final long serialVersionUID = 1L;
+
+    private static final double[] EMPTY = new double[0];
+
+    private static final DenseMatrixSerializer INSTANCE = new 
DenseMatrixSerializer();
+
+    @Override
+    public boolean isImmutableType() {
+        return false;
+    }
+
+    @Override
+    public DenseMatrix createInstance() {
+        return new DenseMatrix(0, 0, EMPTY);
+    }
+
+    @Override
+    public DenseMatrix copy(DenseMatrix from) {
+        return new DenseMatrix(
+                from.numRows, from.numCols, Arrays.copyOf(from.values, 
from.values.length));
+    }
+
+    @Override
+    public DenseMatrix copy(DenseMatrix from, DenseMatrix reuse) {
+        if (from.values.length == reuse.values.length) {
+            System.arraycopy(from.values, 0, reuse.values, 0, 
from.values.length);
+            return reuse;
+        }
+        return copy(from);
+    }
+
+    @Override
+    public int getLength() {
+        return -1;
+    }
+
+    @Override
+    public void serialize(DenseMatrix matrix, DataOutputView target) throws 
IOException {
+        if (matrix == null) {
+            throw new IllegalArgumentException("The matrix must not be null.");
+        }
+
+        final int len = matrix.values.length;
+        target.writeInt(matrix.numRows);
+        target.writeInt(matrix.numCols);
+        for (int i = 0; i < len; i++) {
+            target.writeDouble(matrix.values[i]);
+        }
+    }
+
+    @Override
+    public DenseMatrix deserialize(DataInputView source) throws IOException {
+        int m = source.readInt();
+        int n = source.readInt();
+        double[] values = new double[m * n];
+        for (int i = 0; i < m * n; i++) {
+            values[i] = source.readDouble();
+        }
+        return new DenseMatrix(m, n, values);
+    }
+
+    private static void readDoubleArray(double[] dst, DataInputView source, 
int len)
+            throws IOException {
+        for (int i = 0; i < len; i++) {
+            dst[i] = source.readDouble();
+        }
+    }
+
+    @Override
+    public DenseMatrix deserialize(DenseMatrix reuse, DataInputView source) 
throws IOException {
+        int m = source.readInt();
+        int n = source.readInt();
+
+        double[] values = new double[m * n];
+        readDoubleArray(values, source, m * n);
+        return new DenseMatrix(m, n, values);
+    }
+
+    @Override
+    public void copy(DataInputView source, DataOutputView target) throws 
IOException {
+        int m = source.readInt();
+        target.writeInt(m);
+        int n = source.readInt();
+        target.writeInt(n);
+
+        target.write(source, m * n * 8);

Review comment:
       Can we use `Double.BYTES` here instead of the hard-coded `8`?

##########
File path: flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Matrix.java
##########
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.linalg;
+
+import java.io.Serializable;
+
+/** A matrix of double values. */
+public interface Matrix extends Serializable {
+
+    /** Gets num row. */

Review comment:
       nits: Gets number of rows.

##########
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseMatrixTypeInfoFactory.java
##########
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.linalg.typeinfo;
+
+import org.apache.flink.api.common.typeinfo.TypeInfoFactory;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.TypeExtractor;
+import org.apache.flink.ml.linalg.DenseMatrix;
+
+import java.lang.reflect.Type;
+import java.util.Map;
+
+/**
+ * Used by {@link TypeExtractor} to create a {@link TypeInformation} for 
implementations of {@link
+ * DenseMatrix}.
+ */
+public class DenseMatrixTypeInfoFactory extends TypeInfoFactory<DenseMatrix> {
+
+    @Override
+    public TypeInformation<DenseMatrix> createTypeInfo(
+            Type t, Map<String, TypeInformation<?>> genericParameters) {
+

Review comment:
       nits: remove the empty line here

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
##########
@@ -0,0 +1,217 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.MapPartitionFunctionWrapper;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the KNN algorithm. KNN is to classify 
unlabeled observations by
+ * assigning them to the class of the most similar labeled examples.
+ *
+ * <p>See: https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm.
+ */
+public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> {
+
+    protected Map<Param<?>, Object> params = new HashMap<>();
+
+    /** Constructor. */
+    public Knn() {
+        ParamUtils.initializeMapWithDefaultValues(params, this);
+    }
+
+    /**
+     * Fits data and produces knn model.
+     *
+     * @param inputs A list of tables, including train data table.
+     * @return Knn model.
+     */
+    @Override
+    public KnnModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        DataStream<Row> input = tEnv.toDataStream(inputs[0]);
+        String labelCol = getLabelCol();
+
+        DataStream<Tuple2<DenseVector, Integer>> trainData =
+                input.map(
+                        new MapFunction<Row, Tuple2<DenseVector, Integer>>() {
+                            @Override
+                            public Tuple2<DenseVector, Integer> map(Row value) 
{
+                                Integer label = (Integer) 
value.getField(labelCol);
+                                DenseVector vec = (DenseVector) 
value.getField(getFeaturesCol());
+                                return Tuple2.of(vec, label);
+                            }
+                        });
+
+        DataStream<Row> model = buildModel(trainData);
+        KnnModel knnModel =
+                new KnnModel()
+                        .setFeaturesCol(getFeaturesCol())
+                        .setK(getK())
+                        .setPredictionCol(getPredictionCol());
+        knnModel.setModelData(tEnv.fromDataStream(model, 
KnnModelData.getModelSchema()));
+        return knnModel;
+    }
+
+    /**
+     * Builds knn model.
+     *
+     * @param dataStream Input data.
+     * @return Knn model.
+     */
+    private static DataStream<Row> buildModel(DataStream<Tuple2<DenseVector, 
Integer>> dataStream) {
+        Schema schema = KnnModelData.getModelSchema();
+        return dataStream.transform(
+                "build knn model",
+                TableUtils.getRowTypeInfo(schema),
+                new MapPartitionFunctionWrapper<>(
+                        new RichMapPartitionFunction<Tuple2<DenseVector, 
Integer>, Row>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<Tuple2<DenseVector, Integer>> 
values,
+                                    Collector<Row> out) {
+                                List<Tuple3<DenseMatrix, DenseVector, 
Integer[]>> list =
+                                        prepareMatrixData(values);
+                                for (Tuple3<DenseMatrix, DenseVector, 
Integer[]> t3 : list) {
+                                    Row ret = new Row(3);
+                                    ret.setField(0, t3.f0);
+                                    ret.setField(1, t3.f1);
+                                    ret.setField(2, t3.f2);
+                                    out.collect(ret);
+                                }
+                            }
+                        }));
+    }
+
+    /**
+     * Prepares matrix data, the output is a list of Tuple3, which includes 
vectors, vecNorms and
+     * label.
+     *
+     * @param trainData Input train data.
+     * @return Model data in format of list tuple3.
+     */
+    private static List<Tuple3<DenseMatrix, DenseVector, Integer[]>> 
prepareMatrixData(
+            Iterable<Tuple2<DenseVector, Integer>> trainData) {
+        final int size = 5 * 1024 * 1024;

Review comment:
       what is the intuition for choosing `size=5*1024*1024`?

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasK.java
##########
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared K param. */
+public interface HasK<T> extends WithParams<T> {

Review comment:
       @lindong28 Shall we keep `K` as an inner interface of Knn here, similar to what KMeans does now?
   
   Many machine learning algorithms have a parameter `k`, but its semantics differ between them. Combining all of them into one shared param interface seems like over-design to me.

##########
File path: 
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseMatrixSerializer.java
##########
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.flink.ml.linalg.typeinfo;
+
+import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
+import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.ml.linalg.DenseMatrix;
+
+import java.io.IOException;
+import java.util.Arrays;
+
+/** Specialized serializer for {@code DenseMatrix}. */
+public final class DenseMatrixSerializer extends 
TypeSerializerSingleton<DenseMatrix> {
+
+    private static final long serialVersionUID = 1L;
+
+    private static final double[] EMPTY = new double[0];
+
+    private static final DenseMatrixSerializer INSTANCE = new 
DenseMatrixSerializer();
+
+    @Override
+    public boolean isImmutableType() {
+        return false;
+    }
+
+    @Override
+    public DenseMatrix createInstance() {
+        return new DenseMatrix(0, 0, EMPTY);
+    }
+
+    @Override
+    public DenseMatrix copy(DenseMatrix from) {
+        return new DenseMatrix(
+                from.numRows, from.numCols, Arrays.copyOf(from.values, 
from.values.length));
+    }
+
+    @Override
+    public DenseMatrix copy(DenseMatrix from, DenseMatrix reuse) {
+        if (from.values.length == reuse.values.length) {
+            System.arraycopy(from.values, 0, reuse.values, 0, 
from.values.length);
+            return reuse;
+        }
+        return copy(from);
+    }
+
+    @Override
+    public int getLength() {
+        return -1;
+    }
+
+    @Override
+    public void serialize(DenseMatrix matrix, DataOutputView target) throws 
IOException {
+        if (matrix == null) {
+            throw new IllegalArgumentException("The matrix must not be null.");
+        }
+
+        final int len = matrix.values.length;
+        target.writeInt(matrix.numRows);
+        target.writeInt(matrix.numCols);
+        for (int i = 0; i < len; i++) {
+            target.writeDouble(matrix.values[i]);
+        }
+    }
+
+    @Override
+    public DenseMatrix deserialize(DataInputView source) throws IOException {
+        int m = source.readInt();
+        int n = source.readInt();
+        double[] values = new double[m * n];
+        for (int i = 0; i < m * n; i++) {
+            values[i] = source.readDouble();
+        }
+        return new DenseMatrix(m, n, values);
+    }
+
+    private static void readDoubleArray(double[] dst, DataInputView source, 
int len)
+            throws IOException {
+        for (int i = 0; i < len; i++) {
+            dst[i] = source.readDouble();
+        }
+    }
+
+    @Override
+    public DenseMatrix deserialize(DenseMatrix reuse, DataInputView source) 
throws IOException {
+        int m = source.readInt();
+        int n = source.readInt();
+
+        double[] values = new double[m * n];
+        readDoubleArray(values, source, m * n);
+        return new DenseMatrix(m, n, values);
+    }
+
+    @Override
+    public void copy(DataInputView source, DataOutputView target) throws 
IOException {
+        int m = source.readInt();
+        target.writeInt(m);
+        int n = source.readInt();
+        target.writeInt(n);
+
+        target.write(source, m * n * 8);
+    }
+
+    // ------------------------------------------------------------------------
+
+    @Override
+    public TypeSerializerSnapshot<DenseMatrix> snapshotConfiguration() {
+        return new DenseMatrixSerializerSnapshot();
+    }
+
+    /** Serializer configuration snapshot for compatibility and format 
evolution. */
+    @SuppressWarnings("WeakerAccess")

Review comment:
       what is the purpose of this annotation (`@SuppressWarnings("WeakerAccess")`) here?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to