lindong28 commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r762714568



##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/FastDistanceMatrixData.java
##########
@@ -0,0 +1,122 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Save the data for calculating distance fast. The FastDistanceMatrixData saves several dense
+ * vectors in a single matrix. The vectors are organized in columns, which means each column is a
+ * single vector. For example, vec1: 0,1,2, vec2: 3,4,5, vec3: 6,7,8, then the data in matrix is
+ * organized as: vec1,vec2,vec3. And the data array in <code>vectors</code> is {0,1,2,3,4,5,6,7,8}.
+ */
+public class FastDistanceMatrixData implements Serializable {
+
+    /**
+     * Stores several dense vectors in columns. For example, if the vectorSize is n, and matrix
+     * saves m vectors, then the number of rows of <code>vectors</code> is n and the number of cols
+     * of <code>vectors</code> is m.
+     */
+    public final DenseMatrix vectors;
+    /**
+     * Save the extra info besides the vector. Each vector is related to one row. Thus, for
+     * FastDistanceVectorData, the length of <code>rows</code> is one. And for
+     * FastDistanceMatrixData, the length of <code>rows</code> is equal to the number of cols of
+     * <code>matrix</code>. Besides, the order of the rows are the same with the vectors.
+     */
+    public final String[] ids;
+
+    /**
+     * Stores some extra info extracted from the vector. It's also organized in columns. For
+     * example, if we want to save the L1 norm and L2 norm of the vector, then the two values are
+     * viewed as a two-dimension label vector. We organize the norm vectors together to get the
+     * <code>label</code>. If the number of cols of <code>vectors</code> is m, then in this case the
+     * dimension of <code>label</code> is 2 * m.
+     */
+    public DenseMatrix label;
+
+    public String[] getIds() {
+        return ids;
+    }
+
+    /**
+     * Constructor, initialize the vector data and extra info.
+     *
+     * @param vectors DenseMatrix which saves vectors in columns.
+     * @param ids extra info besides the vector.
+     */
+    public FastDistanceMatrixData(DenseMatrix vectors, String[] ids) {
+        this.ids = ids;
+        Preconditions.checkNotNull(vectors, "DenseMatrix should not be null!");
+        if (null != ids) {
+            Preconditions.checkArgument(
+                    vectors.numCols() == ids.length,
+                    "The column number of DenseMatrix must be equal to the rows array length!");
+        }
+        this.vectors = vectors;
+    }
+
+    /**
+     * serialization of FastDistanceMatrixData.
+     *
+     * @return json string.
+     */
+    @Override
+    public String toString() {

Review comment:
       In general, `toString()` is used for debugging/informational purposes. Could we move the serialization/deserialization logic to a dedicated class, similar to `DenseVectorTypeInfoFactory`?
   
   We probably need a `DenseMatrixTypeInfoFactory` to serialize/deserialize `DenseMatrix`, similar to how `DenseVectorTypeInfoFactory` can be used to serialize/deserialize `DenseVector`. Then we can re-use this logic to serialize/deserialize `FastDistanceMatrixData`, which is effectively a few `DenseMatrix` instances.
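   
   A minimal sketch of the kind of logic such a factory's serializer could own, assuming a column-major matrix described by `numRows`, `numCols` and a `values` array. `SimpleDenseMatrix` and `MatrixCodec` are hypothetical stand-ins rather than the flink-ml API, and a real version would presumably implement Flink's `TypeSerializer` instead of working on raw `java.io` streams. What it illustrates is that the composite type only delegates to the matrix format, so the JSON logic in `toString()` could go away.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

// Hypothetical stand-in for DenseMatrix: numRows x numCols, values stored column-major.
final class SimpleDenseMatrix {
    final int numRows;
    final int numCols;
    final double[] values;

    SimpleDenseMatrix(int numRows, int numCols, double[] values) {
        if (values.length != numRows * numCols) {
            throw new IllegalArgumentException("values.length must equal numRows * numCols");
        }
        this.numRows = numRows;
        this.numCols = numCols;
        this.values = values;
    }
}

// Hypothetical codec holding the logic a DenseMatrix serializer could own.
final class MatrixCodec {
    private MatrixCodec() {}

    static void write(SimpleDenseMatrix m, DataOutputStream out) throws IOException {
        out.writeInt(m.numRows);
        out.writeInt(m.numCols);
        for (double v : m.values) {
            out.writeDouble(v);
        }
    }

    static SimpleDenseMatrix read(DataInputStream in) throws IOException {
        int numRows = in.readInt();
        int numCols = in.readInt();
        double[] values = new double[numRows * numCols];
        for (int i = 0; i < values.length; i++) {
            values[i] = in.readDouble();
        }
        return new SimpleDenseMatrix(numRows, numCols, values);
    }

    // A composite such as FastDistanceMatrixData only writes a count and then
    // delegates each matrix to write(); it never re-implements the matrix format.
    static byte[] writeAll(SimpleDenseMatrix... matrices) {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (DataOutputStream out = new DataOutputStream(bytes)) {
            out.writeInt(matrices.length);
            for (SimpleDenseMatrix m : matrices) {
                write(m, out);
            }
        } catch (IOException e) {
            // Not expected for an in-memory stream, but the java.io API declares it.
            throw new UncheckedIOException(e);
        }
        return bytes.toByteArray();
    }

    static SimpleDenseMatrix[] readAll(byte[] data) {
        try (DataInputStream in = new DataInputStream(new ByteArrayInputStream(data))) {
            SimpleDenseMatrix[] result = new SimpleDenseMatrix[in.readInt()];
            for (int i = 0; i < result.length; i++) {
                result[i] = read(in);
            }
            return result;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```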

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/FastDistanceMatrixData.java
##########
@@ -0,0 +1,122 @@
+    public String[] getIds() {

Review comment:
       In general, we don't need `getXXX` for public final fields. Could this be removed?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/FastDistance.java
##########
@@ -0,0 +1,192 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.flink.shaded.curator4.com.google.common.collect.Iterables;
+
+import dev.ludovic.netlib.blas.F2jBLAS;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+
+/**
+ * FastDistance is an accelerated distance calculating method. It use matrix vector operation to
+ * improve speed of distance calculating.
+ *
+ * <p>The distance type in this class is euclidean distance:
+ *
+ * <p>https://en.wikipedia.org/wiki/Euclidean_distance
+ */
+public class FastDistance implements Serializable {

Review comment:
       Why does this class need to be Serializable? Should all methods of this class be static?
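   
   If the class keeps no per-instance state (the fields in this diff are all `private static final`), one option is a non-instantiable utility class with static methods; since no instance is ever created, nothing has to be serialized with the job. A minimal sketch, where `EuclideanDistances` and `distance` are hypothetical names:

```java
// Hypothetical utility class: no state and no instances, so no need for Serializable.
final class EuclideanDistances {
    private EuclideanDistances() {
        // Prevent instantiation.
    }

    /** Returns the Euclidean distance between two vectors of equal length. */
    static double distance(double[] a, double[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("Vectors must have the same length.");
        }
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double diff = a[i] - b[i];
            sum += diff * diff;
        }
        return Math.sqrt(sum);
    }
}
```

   Call sites would then use `EuclideanDistances.distance(a, b)` directly, without holding a `FastDistance` instance.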

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/FastDistance.java
##########
@@ -0,0 +1,192 @@
+public class FastDistance implements Serializable {
+    /** Label size. */
+    private static final int LABEL_SIZE = 1;
+
+    /** Maximum size of a matrix. */
+    private static final int SIZE = 5 * 1024 * 1024;
+
+    private static final int MAX_ROW_NUMBER = (int) Math.sqrt(200 * 1024 * 1024 / 8.0);
+
+    /** The blas used to accelerating speed. */
+    private static final dev.ludovic.netlib.blas.F2jBLAS NATIVE_BLAS =
+            (F2jBLAS) F2jBLAS.getInstance();
+
+    /**
+     * Prepare the FastDistanceData, the output is a list of FastDistanceMatrixData. As the size of

Review comment:
       The comment seems a bit vague about the relationship between the inputs and the outputs of this method. Could you help improve the comment?
   
   Could we separate the logic of `converting rows into matrix` from the logic of `matrix-related computation`? Then we could move the `matrix-related computation` to infra classes that can be shared among algorithms.
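   
   To make the suggested split concrete: the conversion step only packs vectors into a column-major matrix (the layout described in the `FastDistanceMatrixData` javadoc), and the computation step only does arithmetic, using `||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a.b` so that cached squared norms plus one dot product per column give every distance. A minimal sketch on plain arrays; `ColumnMajorPacker`, `MatrixDistance` and their methods are hypothetical names, and a shared infra class would presumably hand the dot products to BLAS rather than loop by hand.

```java
import java.util.List;

// Hypothetical step 1: data-layout logic only (vectors -> column-major matrix).
final class ColumnMajorPacker {
    private ColumnMajorPacker() {}

    /** Packs equal-length vectors as the columns of a column-major array. */
    static double[] pack(List<double[]> vectors) {
        if (vectors.isEmpty()) {
            throw new IllegalArgumentException("At least one vector is required.");
        }
        int dim = vectors.get(0).length;
        double[] packed = new double[dim * vectors.size()];
        for (int col = 0; col < vectors.size(); col++) {
            double[] v = vectors.get(col);
            if (v.length != dim) {
                throw new IllegalArgumentException("All vectors must have the same length.");
            }
            System.arraycopy(v, 0, packed, col * dim, dim);
        }
        return packed;
    }
}

// Hypothetical step 2: matrix arithmetic only, reusable by other algorithms.
final class MatrixDistance {
    private MatrixDistance() {}

    /** Squared L2 norm of every column of a column-major (dim x numCols) matrix. */
    static double[] columnSquaredNorms(double[] matrix, int dim) {
        double[] norms = new double[matrix.length / dim];
        for (int col = 0; col < norms.length; col++) {
            for (int i = 0; i < dim; i++) {
                double x = matrix[col * dim + i];
                norms[col] += x * x;
            }
        }
        return norms;
    }

    /** Euclidean distance from the query to each column, via ||a||^2 + ||b||^2 - 2 a.b. */
    static double[] distances(double[] matrix, double[] columnSquaredNorms, double[] query) {
        int dim = query.length;
        if (matrix.length != dim * columnSquaredNorms.length) {
            throw new IllegalArgumentException("Matrix shape does not match query and norms.");
        }
        double queryNorm = 0.0;
        for (double q : query) {
            queryNorm += q * q;
        }
        double[] result = new double[columnSquaredNorms.length];
        for (int col = 0; col < result.length; col++) {
            double dot = 0.0;
            for (int i = 0; i < dim; i++) {
                dot += matrix[col * dim + i] * query[i];
            }
            // Clamp at 0: rounding can make the expanded form slightly negative.
            result[col] = Math.sqrt(Math.max(0.0, columnSquaredNorms[col] + queryNorm - 2 * dot));
        }
        return result;
    }
}
```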

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/FastDistanceVectorData.java
##########
@@ -0,0 +1,28 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.Serializable;
+
+/** Save the data for calculating distance fast. The FastDistanceMatrixData */
+public class FastDistanceVectorData implements Serializable, Cloneable {
+    /** Stores the vector(sparse or dense). */
+    final DenseVector vector;
+
+    /**
+     * Stores some extra info extracted from the vector. For example, if we want to save the L1 norm
+     * and L2 norm of the vector, then the two values are viewed as a two-dimension label vector.
+     */
+    public DenseVector label;

Review comment:
       It seems that the `label` length is always 1. Could we make it a double for now?
   
   If we need to make this a vector in the future, could you explain the use case? It looks like whatever normalization setup (e.g. L1, L2) we use, there will be just one double value representing the normalization value of this vector?
   
   And if we do make this a vector, and its length can be determined when `FastDistanceVectorData` is constructed, could this field be `public final`?
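   
   If the cached value is always a single number, the class could shrink to two `public final` fields set once in the constructor. A minimal sketch, assuming the cached value is the squared L2 norm (the term the Euclidean expansion needs); `VectorWithNorm` is a hypothetical name and a `double[]` stands in for `DenseVector`:

```java
import java.util.Objects;

// Hypothetical immutable replacement: the norm is computed once, at construction.
final class VectorWithNorm {
    public final double[] vector;
    public final double squaredNorm;

    VectorWithNorm(double[] vector) {
        this.vector = Objects.requireNonNull(vector, "vector must not be null");
        double sum = 0.0;
        for (double x : vector) {
            sum += x * x;
        }
        this.squaredNorm = sum;
    }
}
```

   Because `squaredNorm` is `final`, it cannot be reassigned after construction (the array contents themselves remain mutable, as with any Java array).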

##########
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##########
@@ -0,0 +1,285 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.api.Stage;
+import org.apache.flink.ml.builder.Pipeline;
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasK;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+
+/** knn algorithm test. */
+public class KnnTest {
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainData;
+
+    List<Row> trainArray =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of("f", Vectors.dense(2.0, 3.0)),
+                            Row.of("f", Vectors.dense(2.1, 3.1)),
+                            Row.of("m", Vectors.dense(200.1, 300.1)),
+                            Row.of("m", Vectors.dense(200.2, 300.2)),
+                            Row.of("m", Vectors.dense(200.3, 300.3)),
+                            Row.of("m", Vectors.dense(200.4, 300.4)),
+                            Row.of("m", Vectors.dense(200.4, 300.4)),
+                            Row.of("m", Vectors.dense(200.6, 300.6)),
+                            Row.of("f", Vectors.dense(2.1, 3.1)),
+                            Row.of("f", Vectors.dense(2.1, 3.1)),
+                            Row.of("f", Vectors.dense(2.1, 3.1)),
+                            Row.of("f", Vectors.dense(2.1, 3.1)),
+                            Row.of("f", Vectors.dense(2.3, 3.2)),
+                            Row.of("f", Vectors.dense(2.3, 3.2)),
+                            Row.of("c", Vectors.dense(2.8, 3.2)),
+                            Row.of("d", Vectors.dense(300., 3.2)),
+                            Row.of("f", Vectors.dense(2.2, 3.2)),
+                            Row.of("e", Vectors.dense(2.4, 3.2)),
+                            Row.of("e", Vectors.dense(2.5, 3.2)),
+                            Row.of("e", Vectors.dense(2.5, 3.2)),
+                            Row.of("f", Vectors.dense(2.1, 3.1))));
+
+    List<Row> testArray =
+            new ArrayList<>(
+                    Arrays.asList(Row.of(Vectors.dense(4.0, 4.1)), Row.of(Vectors.dense(300, 42))));
+    private Table testData;
+
+    Row[] expectedData =
+            new Row[] {Row.of("e", Vectors.dense(4.0, 4.1)), Row.of("m", Vectors.dense(300, 42))};
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        Schema schema =
+                Schema.newBuilder()
+                        .column("f0", DataTypes.STRING())
+                        .column("f1", DataTypes.of(DenseVector.class))
+                        .build();
+
+        DataStream<Row> dataStream = env.fromCollection(trainArray);
+        trainData = tEnv.fromDataStream(dataStream, schema).as("label, vec");
+
+        Schema outputSchema =
+                Schema.newBuilder().column("f0", DataTypes.of(DenseVector.class)).build();
+
+        DataStream<Row> predDataStream = env.fromCollection(testArray);
+        testData = tEnv.fromDataStream(predDataStream, outputSchema).as("vec");
+    }
+
+    /** test knn Estimator. */
+    @Test
+    public void testFitAntTransform() throws Exception {
+        Knn knn =
+                new Knn()
+                        .setLabelCol("label")
+                        .setFeaturesCol("vec")
+                        .setK(4)
+                        .setPredictionCol("pred");
+
+        KnnModel knnModel = knn.fit(trainData);
+        Table result = knnModel.transform(testData)[0];
+
+        DataStream<Row> output = tEnv.toDataStream(result);
+
+        List<Row> rows = IteratorUtils.toList(output.executeAndCollect());
+        for (Row value : rows) {
+            for (Row exp : expectedData) {
+                assert !exp.getField(1).equals(value.getField(0))
+                        || (exp.getField(0).equals(value.getField(1)));
+            }
+        }
+    }
+
+    /** test knn Estimator. */
+    @Test
+    public void testParamsConstructor() throws Exception {
+        Map<Param<?>, Object> params = new HashMap<>();
+        params.put(HasLabelCol.LABEL_COL, "label");
+        params.put(HasFeaturesCol.FEATURES_COL, "vec");
+        params.put(HasK.K, 4);
+        params.put(HasPredictionCol.PREDICTION_COL, "pred");
+        Knn knn = new Knn(params);
+
+        KnnModel knnModel = knn.fit(trainData);
+        Table result = knnModel.transform(testData)[0];
+
+        DataStream<Row> output = tEnv.toDataStream(result);
+
+        List<Row> rows = IteratorUtils.toList(output.executeAndCollect());
+        for (Row value : rows) {
+            for (Row exp : expectedData) {
+                assert !exp.getField(1).equals(value.getField(0))
+                        || (exp.getField(0).equals(value.getField(1)));
+            }
+        }
+    }
+
+    /** test knn as a pipeline stage. */
+    @Test
+    public void testPipeline() throws Exception {
+        Knn knn =
+                new Knn()
+                        .setLabelCol("label")
+                        .setFeaturesCol("vec")
+                        .setK(4)
+                        .setPredictionCol("pred");
+
+        List<Stage<?>> stages = new ArrayList<>();
+        stages.add(knn);
+
+        Pipeline pipe = new Pipeline(stages);
+
+        Table result = pipe.fit(trainData).transform(testData)[0];
+
+        DataStream<Row> output = tEnv.toDataStream(result);
+
+        List<Row> rows = IteratorUtils.toList(output.executeAndCollect());
+        for (Row value : rows) {
+            for (Row exp : expectedData) {
+                assert !exp.getField(1).equals(value.getField(0))
+                        || (exp.getField(0).equals(value.getField(1)));
+            }
+        }
+    }
+
+    /** test knn model load and transform. */
+    @Test
+    public void testEstimatorLoadAndSave() throws Exception {
+        String path = Files.createTempDirectory("").toString();
+        Knn knn =
+                new Knn()
+                        .setLabelCol("label")
+                        .setFeaturesCol("vec")
+                        .setK(4)
+                        .setPredictionCol("pred");
+        knn.save(path);
+
+        Knn loadKnn = Knn.load(path);
+        KnnModel knnModel = loadKnn.fit(trainData);
+        Table result = knnModel.transform(testData)[0];
+
+        DataStream<Row> output = tEnv.toDataStream(result);
+
+        List<Row> rows = IteratorUtils.toList(output.executeAndCollect());
+        for (Row value : rows) {
+            for (Row exp : expectedData) {
+                assert !exp.getField(1).equals(value.getField(0))
+                        || (exp.getField(0).equals(value.getField(1)));
+            }
+        }
+    }
+
+    /** Test knn model load and transform. */
+    @Test
+    public void testModelLoadAndSave() throws Exception {
+        String path = Files.createTempDirectory("").toString();
+        Knn knn =
+                new Knn()
+                        .setLabelCol("label")
+                        .setFeaturesCol("vec")
+                        .setK(4)
+                        .setPredictionCol("pred");
+        KnnModel knnModel = knn.fit(trainData);
+        knnModel.save(path);
+        env.execute();
+
+        KnnModel newModel = KnnModel.load(env, path);
+        Table result = newModel.transform(testData)[0];
+
+        DataStream<Row> output = tEnv.toDataStream(result);
+
+        List<Row> rows = IteratorUtils.toList(output.executeAndCollect());
+        for (Row value : rows) {
+            for (Row exp : expectedData) {
+                assert !exp.getField(1).equals(value.getField(0))
+                        || (exp.getField(0).equals(value.getField(1)));
+            }
+        }
+    }
+
+    /** Test Param. */
+    @Test
+    public void testParam() {
+        Knn knnOrigin = new Knn();
+        assertEquals("label", knnOrigin.getLabelCol());
+        assertEquals(10L, knnOrigin.getK().longValue());
+        assertEquals("prediction", knnOrigin.getPredictionCol());
+
+        Knn knn =
+                new Knn()
+                        .setLabelCol("label")
+                        .setFeaturesCol("vec")
+                        .setK(4)
+                        .setPredictionCol("pred");
+
+        assertEquals("vec", knn.getFeaturesCol());
+        assertEquals("label", knn.getLabelCol());
+        assertEquals(4L, knn.getK().longValue());
+        assertEquals("pred", knn.getPredictionCol());
+    }
+
+    /** Test model data. */
+    @Test
+    public void testModelData() throws Exception {

Review comment:
       It would be nice to make the test structure consistent with other algorithms (see e.g. `KMeansTest` or tests in other pending PRs). If there is a good reason to add/remove/modify the test structure, we can apply that to every algorithm.
   
   Following this logic, maybe we can rename this test to `testGetModelData` and add a `testSetModelData`.
   
   Could you go over all tests in this class and see if there is anything that could be improved?
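   
   One candidate while going over the tests: the checks use the Java `assert` keyword, which only runs when the JVM has assertions enabled (`-ea`; Maven Surefire turns this on by default, but other runners may not), and the nested loop passes silently if a prediction row is missing altogether. A minimal sketch of a stricter, order-independent check, using plain lists in place of `Row` and assuming the output carries the features at index 0 and the prediction at index 1 (the layout the current loop relies on); `PredictionCheck` is a hypothetical helper:

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical test helper: fails unless every expected feature vector appears
// in the output exactly once and carries the expected prediction.
final class PredictionCheck {
    private PredictionCheck() {}

    static void check(
            Map<List<Double>, String> expected,
            List<List<Double>> outputFeatures,
            List<String> outputPredictions) {
        if (outputFeatures.size() != outputPredictions.size()) {
            throw new IllegalArgumentException("Feature and prediction columns must be aligned.");
        }
        if (outputFeatures.size() != expected.size()) {
            throw new AssertionError(
                    "Expected " + expected.size() + " rows but got " + outputFeatures.size());
        }
        Map<List<Double>, String> actual = new HashMap<>();
        for (int i = 0; i < outputFeatures.size(); i++) {
            actual.put(outputFeatures.get(i), outputPredictions.get(i));
        }
        // Map equality ignores row order, which is not guaranteed with parallelism 4.
        if (!expected.equals(actual)) {
            throw new AssertionError("Expected " + expected + " but got " + actual);
        }
    }
}
```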
   

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/FastDistanceMatrixData.java
##########
@@ -0,0 +1,122 @@
+    /**
+     * Save the extra info besides the vector. Each vector is related to one row. Thus, for
+     * FastDistanceVectorData, the length of <code>rows</code> is one. And for
+     * FastDistanceMatrixData, the length of <code>rows</code> is equal to the number of cols of
+     * <code>matrix</code>. Besides, the order of the rows are the same with the vectors.
+     */
+    public final String[] ids;

Review comment:
       Is the comment up-to-date? Does it mean `ids` when it mentions `rows`?
   
   I thought the purpose of `FastDistanceMatrixData` is to cache the L1/L2 norm. Why do we need `ids`?
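   
   If caching norms is the only purpose, the ids could stay with the caller, aligned with the matrix columns by index: the distance step returns one distance per column and the caller maps each index back to its label. A minimal sketch of that caller-side lookup as a k-nearest-neighbour majority vote; `NearestLabels` is a hypothetical name, and breaking ties in favour of the label that reached the top count first is just one possible choice.

```java
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;

// Hypothetical caller-side vote: labels are aligned with matrix columns by index,
// so the distance data itself never needs to store them.
final class NearestLabels {
    private NearestLabels() {}

    static String majorityLabel(double[] distances, String[] labels, int k) {
        if (distances.length != labels.length) {
            throw new IllegalArgumentException("One label per column is required.");
        }
        if (k <= 0) {
            throw new IllegalArgumentException("k must be positive.");
        }
        Integer[] order = new Integer[distances.length];
        for (int i = 0; i < order.length; i++) {
            order[i] = i;
        }
        // Sort column indices by distance, nearest first.
        Arrays.sort(order, Comparator.comparingDouble(i -> distances[i]));
        Map<String, Integer> counts = new HashMap<>();
        String best = null;
        int bestCount = 0;
        for (int n = 0; n < Math.min(k, order.length); n++) {
            String label = labels[order[n]];
            int count = counts.merge(label, 1, Integer::sum);
            // Strictly greater: on a tie, the label that got there first wins.
            if (count > bestCount) {
                best = label;
                bestCount = count;
            }
        }
        return best;
    }
}
```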

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/FastDistanceVectorData.java
##########
@@ -0,0 +1,28 @@
+    /** Stores the vector(sparse or dense). */

Review comment:
       This class supports only `DenseVector` for now. Could we make the 
comment consistent with the implementation?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/FastDistanceVectorData.java
##########
@@ -0,0 +1,28 @@
+public class FastDistanceVectorData implements Serializable, Cloneable {

Review comment:
       Why does this class need to implement Serializable and Cloneable?
   
   And if this class is supposed to be serialized/deserialized, how do we ensure that its serialization/deserialization logic re-uses the serialization/deserialization logic of `DenseVector`?
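   
   One way to guarantee that reuse is composition: the serializer for this class holds the existing vector serializer and only handles the extra field itself, so `Serializable`/`Cloneable` are not needed for this purpose. A minimal sketch with a hypothetical `Codec` interface standing in for Flink's `TypeSerializer`; `VECTOR_CODEC` plays the role of the `DenseVector` serializer, and the `double` label assumes the single-value suggestion from the earlier comment.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

// Hypothetical minimal serializer contract (stand-in for Flink's TypeSerializer).
interface Codec<T> {
    void write(T value, DataOutputStream out) throws IOException;

    T read(DataInputStream in) throws IOException;
}

// Hypothetical value type: a vector plus one cached value.
final class VectorData {
    final double[] vector;
    final double label;

    VectorData(double[] vector, double label) {
        this.vector = vector;
        this.label = label;
    }
}

final class VectorDataCodec implements Codec<VectorData> {
    // Stand-in for the existing DenseVector serializer.
    static final Codec<double[]> VECTOR_CODEC =
            new Codec<double[]>() {
                @Override
                public void write(double[] value, DataOutputStream out) throws IOException {
                    out.writeInt(value.length);
                    for (double v : value) {
                        out.writeDouble(v);
                    }
                }

                @Override
                public double[] read(DataInputStream in) throws IOException {
                    double[] value = new double[in.readInt()];
                    for (int i = 0; i < value.length; i++) {
                        value[i] = in.readDouble();
                    }
                    return value;
                }
            };

    private final Codec<double[]> vectorCodec;

    VectorDataCodec(Codec<double[]> vectorCodec) {
        this.vectorCodec = vectorCodec;
    }

    @Override
    public void write(VectorData value, DataOutputStream out) throws IOException {
        vectorCodec.write(value.vector, out); // delegate: the vector format lives in one place
        out.writeDouble(value.label); // only the extra field is handled here
    }

    @Override
    public VectorData read(DataInputStream in) throws IOException {
        double[] vector = vectorCodec.read(in);
        return new VectorData(vector, in.readDouble());
    }

    /** Writes a value to memory and reads it back, for illustration. */
    static VectorData roundTrip(VectorDataCodec codec, VectorData value) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            DataOutputStream out = new DataOutputStream(bytes);
            codec.write(value, out);
            out.flush();
            return codec.read(new DataInputStream(new ByteArrayInputStream(bytes.toByteArray())));
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```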

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/FastDistanceVectorData.java
##########
@@ -0,0 +1,28 @@
+    /** Stores the vector(sparse or dense). */
+    final DenseVector vector;

Review comment:
       Should this be `public final`?

##########
File path: 
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##########
@@ -0,0 +1,285 @@
+    /** test knn as a pipeline stage. */
+    @Test
+    public void testPipeline() throws Exception {
+        Knn knn =
+                new Knn()
+                        .setLabelCol("label")
+                        .setFeaturesCol("vec")
+                        .setK(4)
+                        .setPredictionCol("pred");
+
+        List<Stage<?>> stages = new ArrayList<>();
+        stages.add(knn);
+
+        Pipeline pipe = new Pipeline(stages);

Review comment:
       It is not clear what extra benefit testing `Pipeline` brings here, since `Pipeline` is already tested in `PipelineTest`. Could we remove/refactor this test to be consistent with tests of other algorithms?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
