lindong28 commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r762714568
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/FastDistanceMatrixData.java
##########
@@ -0,0 +1,122 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Save the data for calculating distance fast. The FastDistanceMatrixData saves several dense
+ * vectors in a single matrix. The vectors are organized in columns, which means each column is a
+ * single vector. For example, vec1: 0,1,2, vec2: 3,4,5, vec3: 6,7,8, then the data in matrix is
+ * organized as: vec1,vec2,vec3. And the data array in <code>vectors</code> is {0,1,2,3,4,5,6,7,8}.
+ */
+public class FastDistanceMatrixData implements Serializable {
+
+ /**
+ * Stores several dense vectors in columns. For example, if the vectorSize is n, and matrix
+ * saves m vectors, then the number of rows of <code>vectors</code> is n and the number of cols
+ * of <code>vectors</code> is m.
+ */
+ public final DenseMatrix vectors;
+ /**
+ * Save the extra info besides the vector. Each vector is related to one row. Thus, for
+ * FastDistanceVectorData, the length of <code>rows</code> is one. And for
+ * FastDistanceMatrixData, the length of <code>rows</code> is equal to the number of cols of
+ * <code>matrix</code>. Besides, the order of the rows are the same with the vectors.
+ */
+ public final String[] ids;
+
+ /**
+ * Stores some extra info extracted from the vector. It's also organized in columns. For
+ * example, if we want to save the L1 norm and L2 norm of the vector, then the two values are
+ * viewed as a two-dimension label vector. We organize the norm vectors together to get the
+ * <code>label</code>. If the number of cols of <code>vectors</code> is m, then in this case the
+ * dimension of <code>label</code> is 2 * m.
+ */
+ public DenseMatrix label;
+
+ public String[] getIds() {
+ return ids;
+ }
+
+ /**
+ * Constructor, initialize the vector data and extra info.
+ *
+ * @param vectors DenseMatrix which saves vectors in columns.
+ * @param ids extra info besides the vector.
+ */
+ public FastDistanceMatrixData(DenseMatrix vectors, String[] ids) {
+ this.ids = ids;
+ Preconditions.checkNotNull(vectors, "DenseMatrix should not be null!");
+ if (null != ids) {
+ Preconditions.checkArgument(
+ vectors.numCols() == ids.length,
"The column number of DenseMatrix must be equal to the rows array length!");
+ }
+ this.vectors = vectors;
+ }
+
+ /**
+ * serialization of FastDistanceMatrixData.
+ *
+ * @return json string.
+ */
+ @Override
+ public String toString() {
Review comment:
In general `toString()` is used for debugging/informational purposes. Could we move the serialization/deserialization logic to a dedicated class, similar to `DenseVectorTypeInfoFactory`?
We probably need a `DenseMatrixTypeInfoFactory` to serialize/deserialize `DenseMatrix`, similar to how `DenseVectorTypeInfoFactory` can be used to serialize/deserialize `DenseVector`. Then we can re-use this logic to serialize/deserialize `FastDistanceMatrixData`, which is effectively a few `DenseMatrix` instances.
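(Aside on the layout being serialized: the column-major packing described in the class javadoc above can be reproduced with a small standalone sketch. Plain Java, no flink-ml types; the class and method names here are hypothetical.)

```java
public class ColumnMajorSketch {

    /**
     * Packs m vectors of length n into one column-major array of length n * m,
     * matching the layout described in the FastDistanceMatrixData javadoc.
     */
    static double[] pack(double[][] vectorsAsColumns) {
        int m = vectorsAsColumns.length;     // number of vectors (matrix columns)
        int n = vectorsAsColumns[0].length;  // vector size (matrix rows)
        double[] data = new double[n * m];
        for (int col = 0; col < m; col++) {
            // Vector `col` occupies data[col * n .. col * n + n - 1].
            System.arraycopy(vectorsAsColumns[col], 0, data, col * n, n);
        }
        return data;
    }
}
```

With vec1 = {0,1,2}, vec2 = {3,4,5}, vec3 = {6,7,8} this yields {0,1,2,3,4,5,6,7,8}, as the javadoc states.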
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/FastDistanceMatrixData.java
##########
@@ -0,0 +1,122 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Save the data for calculating distance fast. The FastDistanceMatrixData saves several dense
+ * vectors in a single matrix. The vectors are organized in columns, which means each column is a
+ * single vector. For example, vec1: 0,1,2, vec2: 3,4,5, vec3: 6,7,8, then the data in matrix is
+ * organized as: vec1,vec2,vec3. And the data array in <code>vectors</code> is {0,1,2,3,4,5,6,7,8}.
+ */
+public class FastDistanceMatrixData implements Serializable {
+
+ /**
+ * Stores several dense vectors in columns. For example, if the vectorSize is n, and matrix
+ * saves m vectors, then the number of rows of <code>vectors</code> is n and the number of cols
+ * of <code>vectors</code> is m.
+ */
+ public final DenseMatrix vectors;
+ /**
+ * Save the extra info besides the vector. Each vector is related to one row. Thus, for
+ * FastDistanceVectorData, the length of <code>rows</code> is one. And for
+ * FastDistanceMatrixData, the length of <code>rows</code> is equal to the number of cols of
+ * <code>matrix</code>. Besides, the order of the rows are the same with the vectors.
+ */
+ public final String[] ids;
+
+ /**
+ * Stores some extra info extracted from the vector. It's also organized in columns. For
+ * example, if we want to save the L1 norm and L2 norm of the vector, then the two values are
+ * viewed as a two-dimension label vector. We organize the norm vectors together to get the
+ * <code>label</code>. If the number of cols of <code>vectors</code> is m, then in this case the
+ * dimension of <code>label</code> is 2 * m.
+ */
+ public DenseMatrix label;
+
+ public String[] getIds() {
Review comment:
In general we don't need `getXXX` for public final fields. Could this be removed?
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/FastDistance.java
##########
@@ -0,0 +1,192 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.flink.shaded.curator4.com.google.common.collect.Iterables;
+
+import dev.ludovic.netlib.blas.F2jBLAS;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+
+/**
+ * FastDistance is an accelerated distance calculating method. It uses matrix-vector operations to
+ * improve the speed of distance calculation.
+ *
+ * <p>The distance type in this class is euclidean distance:
+ *
+ * <p>https://en.wikipedia.org/wiki/Euclidean_distance
+ */
+public class FastDistance implements Serializable {
Review comment:
Why does this class need to be `Serializable`? Should all methods of this class be static?
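For background on why batching into a matrix pays off: euclidean distance decomposes as d(a,b)^2 = ||a||^2 + ||b||^2 - 2 * a.b, so once the squared norms are cached per vector, only dot products remain, and a single BLAS GEMM can compute all cross terms for a whole matrix of vectors at once. A minimal standalone sketch of the identity (plain Java, not the PR's implementation):

```java
public class EuclideanSketch {

    /** Squared L2 norm: the per-vector quantity worth caching. */
    static double squaredNorm(double[] v) {
        double s = 0.0;
        for (double x : v) {
            s += x * x;
        }
        return s;
    }

    /** d(a, b)^2 = ||a||^2 + ||b||^2 - 2 * a.b, using precomputed norms. */
    static double squaredDistance(double[] a, double normA2, double[] b, double normB2) {
        double dot = 0.0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
        }
        return normA2 + normB2 - 2.0 * dot;
    }
}
```

The per-vector norms only need to be computed once per dataset, which is presumably what the cached extra info in `FastDistanceMatrixData`/`FastDistanceVectorData` is for.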
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/FastDistance.java
##########
@@ -0,0 +1,192 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.flink.shaded.curator4.com.google.common.collect.Iterables;
+
+import dev.ludovic.netlib.blas.F2jBLAS;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+
+/**
+ * FastDistance is an accelerated distance calculating method. It uses matrix-vector operations to
+ * improve the speed of distance calculation.
+ *
+ * <p>The distance type in this class is euclidean distance:
+ *
+ * <p>https://en.wikipedia.org/wiki/Euclidean_distance
+ */
+public class FastDistance implements Serializable {
+ /** Label size. */
+ private static final int LABEL_SIZE = 1;
+
+ /** Maximum size of a matrix. */
+ private static final int SIZE = 5 * 1024 * 1024;
+
+ private static final int MAX_ROW_NUMBER = (int) Math.sqrt(200 * 1024 * 1024 / 8.0);
+
+ /** The blas used to accelerating speed. */
+ private static final dev.ludovic.netlib.blas.F2jBLAS NATIVE_BLAS =
+ (F2jBLAS) F2jBLAS.getInstance();
+
+ /**
+ * Prepare the FastDistanceData, the output is a list of FastDistanceMatrixData. As the size of
Review comment:
The comment seems to be a bit vague regarding the relationship between the inputs and the outputs of this method. Could you help improve the comment?
Could we separate the logic of `converting rows into matrix` from the logic of `matrix-related computation`? Then we could move the `matrix-related computation` to infra classes that can be shared among algorithms.
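To illustrate the suggested separation: the `converting rows into matrix` step can be a generic utility with no distance math in it at all; it only partitions vectors into fixed-width column-major blocks. A hypothetical sketch (the names and the block-size policy here are made up, not the PR's code):

```java
import java.util.ArrayList;
import java.util.List;

public class MatrixBatcher {

    /**
     * Partitions vectors (all of length n) into column-major blocks holding at
     * most maxCols vectors each. Distance computation stays a separate concern.
     */
    static List<double[]> toColumnMajorBlocks(List<double[]> vectors, int maxCols) {
        List<double[]> blocks = new ArrayList<>();
        int n = vectors.get(0).length;
        for (int start = 0; start < vectors.size(); start += maxCols) {
            int cols = Math.min(maxCols, vectors.size() - start);
            double[] block = new double[n * cols];
            for (int c = 0; c < cols; c++) {
                // Vector (start + c) becomes column c of the current block.
                System.arraycopy(vectors.get(start + c), 0, block, c * n, n);
            }
            blocks.add(block);
        }
        return blocks;
    }
}
```

A shared infra class shaped like this could then be reused by any algorithm that wants to batch rows for BLAS calls, with `maxCols` derived from a memory budget such as the `SIZE`/`MAX_ROW_NUMBER` constants above.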
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/FastDistanceVectorData.java
##########
@@ -0,0 +1,28 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.Serializable;
+
+/** Save the data for calculating distance fast. The FastDistanceMatrixData */
+public class FastDistanceVectorData implements Serializable, Cloneable {
+ /** Stores the vector(sparse or dense). */
+ final DenseVector vector;
+
+ /**
+ * Stores some extra info extracted from the vector. For example, if we want to save the L1 norm
+ * and L2 norm of the vector, then the two values are viewed as a two-dimension label vector.
+ */
+ public DenseVector label;
Review comment:
It seems that the `label` length is always 1. Could we make it a double for now?
If we need to make this a vector in the future, could you explain the use-case? It looks like whatever normalization setup (e.g. l1, l2) we use, there will be just one double value representing the normalization value of this vector.
And suppose we do make this a vector: if its length can be determined when `FastDistanceVectorData` is constructed, could this field be `public final`?
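If `label` does become a single double, it can be computed once in the constructor and declared `public final`, which also settles the mutability question. A hypothetical sketch (assuming the cached value is the squared L2 norm; this is not the actual PR code):

```java
public class VectorDataSketch {

    final double[] vector;

    /** Cached squared L2 norm; fixed at construction, so the field can be final. */
    public final double label;

    VectorDataSketch(double[] vector) {
        this.vector = vector;
        double s = 0.0;
        for (double x : vector) {
            s += x * x;
        }
        this.label = s;
    }
}
```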
##########
File path:
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##########
@@ -0,0 +1,285 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.api.Stage;
+import org.apache.flink.ml.builder.Pipeline;
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasK;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+
+/** knn algorithm test. */
+public class KnnTest {
+ private StreamExecutionEnvironment env;
+ private StreamTableEnvironment tEnv;
+ private Table trainData;
+
+ List<Row> trainArray =
+ new ArrayList<>(
+ Arrays.asList(
+ Row.of("f", Vectors.dense(2.0, 3.0)),
+ Row.of("f", Vectors.dense(2.1, 3.1)),
+ Row.of("m", Vectors.dense(200.1, 300.1)),
+ Row.of("m", Vectors.dense(200.2, 300.2)),
+ Row.of("m", Vectors.dense(200.3, 300.3)),
+ Row.of("m", Vectors.dense(200.4, 300.4)),
+ Row.of("m", Vectors.dense(200.4, 300.4)),
+ Row.of("m", Vectors.dense(200.6, 300.6)),
+ Row.of("f", Vectors.dense(2.1, 3.1)),
+ Row.of("f", Vectors.dense(2.1, 3.1)),
+ Row.of("f", Vectors.dense(2.1, 3.1)),
+ Row.of("f", Vectors.dense(2.1, 3.1)),
+ Row.of("f", Vectors.dense(2.3, 3.2)),
+ Row.of("f", Vectors.dense(2.3, 3.2)),
+ Row.of("c", Vectors.dense(2.8, 3.2)),
+ Row.of("d", Vectors.dense(300., 3.2)),
+ Row.of("f", Vectors.dense(2.2, 3.2)),
+ Row.of("e", Vectors.dense(2.4, 3.2)),
+ Row.of("e", Vectors.dense(2.5, 3.2)),
+ Row.of("e", Vectors.dense(2.5, 3.2)),
+ Row.of("f", Vectors.dense(2.1, 3.1))));
+
+ List<Row> testArray =
+ new ArrayList<>(
+ Arrays.asList(Row.of(Vectors.dense(4.0, 4.1)), Row.of(Vectors.dense(300, 42))));
+ private Table testData;
+
+ Row[] expectedData =
+ new Row[] {Row.of("e", Vectors.dense(4.0, 4.1)), Row.of("m", Vectors.dense(300, 42))};
+
+ @Before
+ public void before() {
+ Configuration config = new Configuration();
+ config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+ env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+ env.setParallelism(4);
+ env.enableCheckpointing(100);
+ env.setRestartStrategy(RestartStrategies.noRestart());
+ tEnv = StreamTableEnvironment.create(env);
+
+ Schema schema =
+ Schema.newBuilder()
+ .column("f0", DataTypes.STRING())
+ .column("f1", DataTypes.of(DenseVector.class))
+ .build();
+
+ DataStream<Row> dataStream = env.fromCollection(trainArray);
+ trainData = tEnv.fromDataStream(dataStream, schema).as("label, vec");
+
+ Schema outputSchema =
+ Schema.newBuilder().column("f0", DataTypes.of(DenseVector.class)).build();
+
+ DataStream<Row> predDataStream = env.fromCollection(testArray);
+ testData = tEnv.fromDataStream(predDataStream, outputSchema).as("vec");
+ }
+
+ /** test knn Estimator. */
+ @Test
+ public void testFitAntTransform() throws Exception {
+ Knn knn =
+ new Knn()
+ .setLabelCol("label")
+ .setFeaturesCol("vec")
+ .setK(4)
+ .setPredictionCol("pred");
+
+ KnnModel knnModel = knn.fit(trainData);
+ Table result = knnModel.transform(testData)[0];
+
+ DataStream<Row> output = tEnv.toDataStream(result);
+
+ List<Row> rows = IteratorUtils.toList(output.executeAndCollect());
+ for (Row value : rows) {
+ for (Row exp : expectedData) {
+ assert !exp.getField(1).equals(value.getField(0))
+ || (exp.getField(0).equals(value.getField(1)));
+ }
+ }
+ }
+
+ /** test knn Estimator. */
+ @Test
+ public void testParamsConstructor() throws Exception {
+ Map<Param<?>, Object> params = new HashMap<>();
+ params.put(HasLabelCol.LABEL_COL, "label");
+ params.put(HasFeaturesCol.FEATURES_COL, "vec");
+ params.put(HasK.K, 4);
+ params.put(HasPredictionCol.PREDICTION_COL, "pred");
+ Knn knn = new Knn(params);
+
+ KnnModel knnModel = knn.fit(trainData);
+ Table result = knnModel.transform(testData)[0];
+
+ DataStream<Row> output = tEnv.toDataStream(result);
+
+ List<Row> rows = IteratorUtils.toList(output.executeAndCollect());
+ for (Row value : rows) {
+ for (Row exp : expectedData) {
+ assert !exp.getField(1).equals(value.getField(0))
+ || (exp.getField(0).equals(value.getField(1)));
+ }
+ }
+ }
+
+ /** test knn as a pipeline stage. */
+ @Test
+ public void testPipeline() throws Exception {
+ Knn knn =
+ new Knn()
+ .setLabelCol("label")
+ .setFeaturesCol("vec")
+ .setK(4)
+ .setPredictionCol("pred");
+
+ List<Stage<?>> stages = new ArrayList<>();
+ stages.add(knn);
+
+ Pipeline pipe = new Pipeline(stages);
+
+ Table result = pipe.fit(trainData).transform(testData)[0];
+
+ DataStream<Row> output = tEnv.toDataStream(result);
+
+ List<Row> rows = IteratorUtils.toList(output.executeAndCollect());
+ for (Row value : rows) {
+ for (Row exp : expectedData) {
+ assert !exp.getField(1).equals(value.getField(0))
+ || (exp.getField(0).equals(value.getField(1)));
+ }
+ }
+ }
+
+ /** test knn model load and transform. */
+ @Test
+ public void testEstimatorLoadAndSave() throws Exception {
+ String path = Files.createTempDirectory("").toString();
+ Knn knn =
+ new Knn()
+ .setLabelCol("label")
+ .setFeaturesCol("vec")
+ .setK(4)
+ .setPredictionCol("pred");
+ knn.save(path);
+
+ Knn loadKnn = Knn.load(path);
+ KnnModel knnModel = loadKnn.fit(trainData);
+ Table result = knnModel.transform(testData)[0];
+
+ DataStream<Row> output = tEnv.toDataStream(result);
+
+ List<Row> rows = IteratorUtils.toList(output.executeAndCollect());
+ for (Row value : rows) {
+ for (Row exp : expectedData) {
+ assert !exp.getField(1).equals(value.getField(0))
+ || (exp.getField(0).equals(value.getField(1)));
+ }
+ }
+ }
+
+ /** Test knn model load and transform. */
+ @Test
+ public void testModelLoadAndSave() throws Exception {
+ String path = Files.createTempDirectory("").toString();
+ Knn knn =
+ new Knn()
+ .setLabelCol("label")
+ .setFeaturesCol("vec")
+ .setK(4)
+ .setPredictionCol("pred");
+ KnnModel knnModel = knn.fit(trainData);
+ knnModel.save(path);
+ env.execute();
+
+ KnnModel newModel = KnnModel.load(env, path);
+ Table result = newModel.transform(testData)[0];
+
+ DataStream<Row> output = tEnv.toDataStream(result);
+
+ List<Row> rows = IteratorUtils.toList(output.executeAndCollect());
+ for (Row value : rows) {
+ for (Row exp : expectedData) {
+ assert !exp.getField(1).equals(value.getField(0))
+ || (exp.getField(0).equals(value.getField(1)));
+ }
+ }
+ }
+
+ /** Test Param. */
+ @Test
+ public void testParam() {
+ Knn knnOrigin = new Knn();
+ assertEquals("label", knnOrigin.getLabelCol());
+ assertEquals(10L, knnOrigin.getK().longValue());
+ assertEquals("prediction", knnOrigin.getPredictionCol());
+
+ Knn knn =
+ new Knn()
+ .setLabelCol("label")
+ .setFeaturesCol("vec")
+ .setK(4)
+ .setPredictionCol("pred");
+
+ assertEquals("vec", knn.getFeaturesCol());
+ assertEquals("label", knn.getLabelCol());
+ assertEquals(4L, knn.getK().longValue());
+ assertEquals("pred", knn.getPredictionCol());
+ }
+
+ /** Test model data. */
+ @Test
+ public void testModelData() throws Exception {
Review comment:
It will be nice to make the test structure consistent with other algorithms (see e.g. `KMeansTest` or tests in other pending PRs). If there is a good reason to add/remove/modify the test structure, we can apply that to every algorithm.
Following this logic, maybe we can rename this test as `testGetModelData` and add a `testSetModelData`.
Could you go over all tests in this class and see if there is anything that could be improved?
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/FastDistanceMatrixData.java
##########
@@ -0,0 +1,122 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Save the data for calculating distance fast. The FastDistanceMatrixData saves several dense
+ * vectors in a single matrix. The vectors are organized in columns, which means each column is a
+ * single vector. For example, vec1: 0,1,2, vec2: 3,4,5, vec3: 6,7,8, then the data in matrix is
+ * organized as: vec1,vec2,vec3. And the data array in <code>vectors</code> is {0,1,2,3,4,5,6,7,8}.
+ */
+public class FastDistanceMatrixData implements Serializable {
+
+ /**
+ * Stores several dense vectors in columns. For example, if the vectorSize is n, and matrix
+ * saves m vectors, then the number of rows of <code>vectors</code> is n and the number of cols
+ * of <code>vectors</code> is m.
+ */
+ public final DenseMatrix vectors;
+ /**
+ * Save the extra info besides the vector. Each vector is related to one row. Thus, for
+ * FastDistanceVectorData, the length of <code>rows</code> is one. And for
+ * FastDistanceMatrixData, the length of <code>rows</code> is equal to the number of cols of
+ * <code>matrix</code>. Besides, the order of the rows are the same with the vectors.
+ */
+ public final String[] ids;
Review comment:
Is the comment up-to-date? Does it mean `ids` when it mentions `rows`?
I thought the purpose of `FastDistanceMatrixData` is to cache the l1/l2 norm. Why do we need `ids`?
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/FastDistanceVectorData.java
##########
@@ -0,0 +1,28 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.Serializable;
+
+/** Save the data for calculating distance fast. The FastDistanceMatrixData */
+public class FastDistanceVectorData implements Serializable, Cloneable {
+ /** Stores the vector(sparse or dense). */
Review comment:
This class supports only `DenseVector` for now. Could we make the comment consistent with the implementation?
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/FastDistanceVectorData.java
##########
@@ -0,0 +1,28 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.Serializable;
+
+/** Save the data for calculating distance fast. The FastDistanceMatrixData */
+public class FastDistanceVectorData implements Serializable, Cloneable {
Review comment:
Why does this class need to implement `Serializable` and `Cloneable`?
And if this class is supposed to be serialized/deserialized, how do we ensure that its serialization/deserialization logic would re-use the serialization/deserialization logic of `DenseVector`?
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/FastDistanceVectorData.java
##########
@@ -0,0 +1,28 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.Serializable;
+
+/** Save the data for calculating distance fast. The FastDistanceMatrixData */
+public class FastDistanceVectorData implements Serializable, Cloneable {
+ /** Stores the vector(sparse or dense). */
+ final DenseVector vector;
Review comment:
Should this be `public final`?
##########
File path:
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##########
@@ -0,0 +1,285 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.api.Stage;
+import org.apache.flink.ml.builder.Pipeline;
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasK;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+
+/** knn algorithm test. */
+public class KnnTest {
+ private StreamExecutionEnvironment env;
+ private StreamTableEnvironment tEnv;
+ private Table trainData;
+
+ List<Row> trainArray =
+ new ArrayList<>(
+ Arrays.asList(
+ Row.of("f", Vectors.dense(2.0, 3.0)),
+ Row.of("f", Vectors.dense(2.1, 3.1)),
+ Row.of("m", Vectors.dense(200.1, 300.1)),
+ Row.of("m", Vectors.dense(200.2, 300.2)),
+ Row.of("m", Vectors.dense(200.3, 300.3)),
+ Row.of("m", Vectors.dense(200.4, 300.4)),
+ Row.of("m", Vectors.dense(200.4, 300.4)),
+ Row.of("m", Vectors.dense(200.6, 300.6)),
+ Row.of("f", Vectors.dense(2.1, 3.1)),
+ Row.of("f", Vectors.dense(2.1, 3.1)),
+ Row.of("f", Vectors.dense(2.1, 3.1)),
+ Row.of("f", Vectors.dense(2.1, 3.1)),
+ Row.of("f", Vectors.dense(2.3, 3.2)),
+ Row.of("f", Vectors.dense(2.3, 3.2)),
+ Row.of("c", Vectors.dense(2.8, 3.2)),
+ Row.of("d", Vectors.dense(300., 3.2)),
+ Row.of("f", Vectors.dense(2.2, 3.2)),
+ Row.of("e", Vectors.dense(2.4, 3.2)),
+ Row.of("e", Vectors.dense(2.5, 3.2)),
+ Row.of("e", Vectors.dense(2.5, 3.2)),
+ Row.of("f", Vectors.dense(2.1, 3.1))));
+
+ List<Row> testArray =
+ new ArrayList<>(
+ Arrays.asList(Row.of(Vectors.dense(4.0, 4.1)), Row.of(Vectors.dense(300, 42))));
+ private Table testData;
+
+ Row[] expectedData =
+ new Row[] {Row.of("e", Vectors.dense(4.0, 4.1)), Row.of("m", Vectors.dense(300, 42))};
+
+ @Before
+ public void before() {
+ Configuration config = new Configuration();
+ config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+ env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+ env.setParallelism(4);
+ env.enableCheckpointing(100);
+ env.setRestartStrategy(RestartStrategies.noRestart());
+ tEnv = StreamTableEnvironment.create(env);
+
+ Schema schema =
+ Schema.newBuilder()
+ .column("f0", DataTypes.STRING())
+ .column("f1", DataTypes.of(DenseVector.class))
+ .build();
+
+ DataStream<Row> dataStream = env.fromCollection(trainArray);
+ trainData = tEnv.fromDataStream(dataStream, schema).as("label, vec");
+
+ Schema outputSchema =
+ Schema.newBuilder().column("f0", DataTypes.of(DenseVector.class)).build();
+
+ DataStream<Row> predDataStream = env.fromCollection(testArray);
+ testData = tEnv.fromDataStream(predDataStream, outputSchema).as("vec");
+ }
+
+ /** test knn Estimator. */
+ @Test
+ public void testFitAntTransform() throws Exception {
+ Knn knn =
+ new Knn()
+ .setLabelCol("label")
+ .setFeaturesCol("vec")
+ .setK(4)
+ .setPredictionCol("pred");
+
+ KnnModel knnModel = knn.fit(trainData);
+ Table result = knnModel.transform(testData)[0];
+
+ DataStream<Row> output = tEnv.toDataStream(result);
+
+ List<Row> rows = IteratorUtils.toList(output.executeAndCollect());
+ for (Row value : rows) {
+ for (Row exp : expectedData) {
+ assert !exp.getField(1).equals(value.getField(0))
+ || (exp.getField(0).equals(value.getField(1)));
+ }
+ }
+ }
+
+ /** test knn Estimator. */
+ @Test
+ public void testParamsConstructor() throws Exception {
+ Map<Param<?>, Object> params = new HashMap<>();
+ params.put(HasLabelCol.LABEL_COL, "label");
+ params.put(HasFeaturesCol.FEATURES_COL, "vec");
+ params.put(HasK.K, 4);
+ params.put(HasPredictionCol.PREDICTION_COL, "pred");
+ Knn knn = new Knn(params);
+
+ KnnModel knnModel = knn.fit(trainData);
+ Table result = knnModel.transform(testData)[0];
+
+ DataStream<Row> output = tEnv.toDataStream(result);
+
+ List<Row> rows = IteratorUtils.toList(output.executeAndCollect());
+ for (Row value : rows) {
+ for (Row exp : expectedData) {
+ assert !exp.getField(1).equals(value.getField(0))
+ || (exp.getField(0).equals(value.getField(1)));
+ }
+ }
+ }
+
+ /** test knn as a pipeline stage. */
+ @Test
+ public void testPipeline() throws Exception {
+ Knn knn =
+ new Knn()
+ .setLabelCol("label")
+ .setFeaturesCol("vec")
+ .setK(4)
+ .setPredictionCol("pred");
+
+ List<Stage<?>> stages = new ArrayList<>();
+ stages.add(knn);
+
+ Pipeline pipe = new Pipeline(stages);
Review comment:
It is not clear what the extra benefit of testing `Pipeline` here is. `Pipeline` is tested in `PipelineTest`. Could we remove/refactor this test to be consistent with the tests of other algorithms?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]