lindong28 commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r763682700
##########
File path:
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/DenseMatrix.java
##########
@@ -0,0 +1,47 @@
+package org.apache.flink.ml.linalg;
+
+import java.io.Serializable;
+
+/**
+ * DenseMatrix stores dense matrix data and provides some methods to operate on the matrix it
+ * represents.
+ */
+public class DenseMatrix implements Serializable {
+
+ /** Row dimension. */
+ public int numRows;
Review comment:
Should this field be `final`? Same for `numCols` and `values`.
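For reference, here is a minimal sketch of the fields declared `final`. Field names follow the PR; everything else is simplified. One caveat: `final` on `values` only prevents reassigning the array reference. The elements stay writable, and `calc(...)` in `KnnModel` (later in this PR) relies on that when it fills `res.values`.

```java
import java.io.Serializable;

/** Sketch of DenseMatrix with final fields; field names follow the PR, the rest is simplified. */
class DenseMatrix implements Serializable {
    private static final long serialVersionUID = 1L;

    /** Row dimension. */
    public final int numRows;

    /** Column dimension. */
    public final int numCols;

    /** Column-major storage. The reference is final, but the elements can still be written. */
    public final double[] values;

    /** Constructs a numRows-by-numCols matrix of zeros. */
    public DenseMatrix(int numRows, int numCols) {
        this(numRows, numCols, new double[numRows * numCols]);
    }

    public DenseMatrix(int numRows, int numCols, double[] values) {
        this.numRows = numRows;
        this.numCols = numCols;
        this.values = values;
    }
}

class FinalFieldsDemo {
    public static void main(String[] args) {
        DenseMatrix res = new DenseMatrix(2, 1);
        // Element writes still compile, so code that fills res.values keeps working.
        res.values[0] = 3.0;
        // res.values = new double[2]; // would not compile: values is final
        // res.numRows = 3;            // would not compile: numRows is final
        System.out.println(res.numRows + "x" + res.numCols + ", values[0]=" + res.values[0]);
    }
}
```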
##########
File path:
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseMatrixTypeInfo.java
##########
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.linalg.typeinfo;
+
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.ml.linalg.DenseMatrix;
+
+/** A {@link TypeInformation} for the {@link DenseMatrix} type. */
+public class DenseMatrixTypeInfo extends TypeInformation<DenseMatrix> {
+ private static final long serialVersionUID = 1L;
+
+ public static final DenseMatrixTypeInfo INSTANCE = new DenseMatrixTypeInfo();
+
+ public DenseMatrixTypeInfo() {}
+
+ @Override
+ public int getArity() {
Review comment:
According to the Java doc of `TypeInformation::getArity`, the return
value is `the number of fields without nesting`. Since `DenseMatrix` has 3
fields, i.e. `numRows`, `numCols` and `values`, should we return `3` here?
Same for `getTotalFields`.
##########
File path:
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/DenseMatrix.java
##########
@@ -0,0 +1,47 @@
+package org.apache.flink.ml.linalg;
+
+import java.io.Serializable;
+
+/**
+ * DenseMatrix stores dense matrix data and provides some methods to operate on the matrix it
Review comment:
nit: Could you help improve the Java doc here to mention "this is a
column-major dense matrix"? This is useful information to state explicitly.
Could we also simplify the Java doc by removing `some methods to operate on
the matrix`, since this information seems a bit redundant? We don't specify it
in most other classes, such as `DenseVector`, and it would be nice to keep the
Java doc pattern consistent.
Maybe we could use the following Java doc, which comes from Spark:
```
* Column-major dense matrix.
* The entry values are stored in a single array of doubles with columns listed in sequence.
```
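To illustrate what "columns listed in sequence" means: entry (i, j) of an m-by-n column-major matrix lives at `values[j * m + i]`. Here is a minimal sketch. The `get` helper is hypothetical and not an existing method in the PR.

```java
/** Sketch of column-major indexing, matching the Java doc suggested above. */
class ColumnMajorDemo {
    /** Returns entry (i, j) of a column-major matrix with numRows rows. */
    static double get(double[] values, int numRows, int i, int j) {
        // Column j occupies values[j * numRows] .. values[j * numRows + numRows - 1].
        return values[j * numRows + i];
    }

    public static void main(String[] args) {
        // The 2x3 matrix
        //   [1 3 5]
        //   [2 4 6]
        // stored column by column:
        double[] values = {1, 2, 3, 4, 5, 6};
        System.out.println(get(values, 2, 0, 2)); // prints 5.0 (row 0, column 2)
    }
}
```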
##########
File path:
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/DenseMatrix.java
##########
@@ -0,0 +1,47 @@
+package org.apache.flink.ml.linalg;
+
+import java.io.Serializable;
+
+/**
+ * DenseMatrix stores dense matrix data and provides some methods to operate on the matrix it
+ * represents.
+ */
Review comment:
Should we add `@TypeInfo(DenseMatrixTypeInfoFactory.class)` here? See
`DenseVector` for an example.
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/common/datastream/TableUtils.java
##########
@@ -38,4 +43,28 @@ public static RowTypeInfo getRowTypeInfo(ResolvedSchema schema) {
}
return new RowTypeInfo(types, names);
}
+
+ /** Constructs a RowTypeInfo from the given schema. */
+ public static RowTypeInfo getRowTypeInfo(Schema schema) {
Review comment:
It looks like this method is used to process the output of
`KnnModelData.getModelSchema`, which generates a `Schema` instead of a
`ResolvedSchema`.
Since we have all the information about the schema of `KnnModelData`,
could we generate its schema as a `ResolvedSchema` directly? This would allow us
to simplify the PR by removing `RowTypeInfo getRowTypeInfo(Schema schema)`.
BTW, other algorithms (i.e. `NaiveBayes`, `LogisticRegression` and `KMeans`)
may have already solved this problem in other/simpler ways. Maybe we can learn
from each other's PRs and hopefully make the solutions consistent.
##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasK.java
##########
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared K param. */
+public interface HasK<T> extends WithParams<T> {
Review comment:
The definition of this parameter is pretty much the same as the existing
`KMeansModelParams::K`. Could we update `KMeansModelParams` to use `HasK`?
We can update the Java doc to something like `The number of clusters or
neighbors of the algorithm`. And if the default value for `KNN` cannot be the
same as that of `KMeans` (which has `k=2`), one of the algorithms can override
the default value of this parameter in its estimator/model constructor.
What do you think?
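Here is a rough, framework-free sketch of the idea: one shared param interface with a single default, plus an algorithm that overrides that default in its constructor. `SimpleWithParams`, `HasKSketch`, `KMeansLike`, `KnnLike` and the KNN default of 5 are all hypothetical stand-ins for illustration. In the PR, this would go through Flink ML's `IntParam`, `ParamValidators` and `WithParams`, which `HasK` already imports.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical minimal param holder; a stand-in for Flink ML's WithParams. */
interface SimpleWithParams<T> {
    Map<String, Object> paramMap();

    @SuppressWarnings("unchecked")
    default T set(String name, Object value) {
        paramMap().put(name, value);
        return (T) this;
    }
}

/** Shared "k" param: defined once, reused by KMeans-like and KNN-like algorithms. */
interface HasKSketch<T> extends SimpleWithParams<T> {
    String K = "k";
    /** Shared default; 2 mirrors the KMeans default mentioned above. */
    int DEFAULT_K = 2;

    default int getK() {
        return (Integer) paramMap().getOrDefault(K, DEFAULT_K);
    }

    default T setK(int value) {
        if (value <= 0) {
            throw new IllegalArgumentException("k must be positive, got " + value);
        }
        return set(K, value);
    }
}

/** Uses the shared default as-is. */
class KMeansLike implements HasKSketch<KMeansLike> {
    private final Map<String, Object> params = new HashMap<>();

    @Override
    public Map<String, Object> paramMap() {
        return params;
    }
}

/** Overrides the shared default in its constructor. */
class KnnLike implements HasKSketch<KnnLike> {
    private final Map<String, Object> params = new HashMap<>();

    KnnLike() {
        setK(5); // hypothetical KNN default, for illustration only
    }

    @Override
    public Map<String, Object> paramMap() {
        return params;
    }
}

class SharedParamDemo {
    public static void main(String[] args) {
        System.out.println(new KMeansLike().getK()); // 2
        System.out.println(new KnnLike().getK()); // 5
    }
}
```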
##########
File path:
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/DenseMatrix.java
##########
@@ -0,0 +1,47 @@
+package org.apache.flink.ml.linalg;
+
+import java.io.Serializable;
+
+/**
+ * DenseMatrix stores dense matrix data and provides some methods to operate on the matrix it
+ * represents.
+ */
+public class DenseMatrix implements Serializable {
+
+ /** Row dimension. */
+ public int numRows;
+
+ /** Column dimension. */
+ public int numCols;
+
+ /**
+ * Array for internal storage of elements.
+ *
+ * <p>The matrix data is stored in column major format internally.
+ */
+ public double[] values;
+
+ /**
+ * Constructs an m-by-n matrix of zeros.
+ *
+ * @param numRows Number of rows.
+ * @param numCols Number of columns.
+ */
+ public DenseMatrix(int numRows, int numCols) {
+ this(numRows, numCols, new double[numRows * numCols]);
+ }
+
+ /**
+ * Constructs a matrix from a 1-D array. The data in the array should organize in column major.
+ *
+ * @param numRows Number of rows.
+ * @param numCols Number of cols.
+ * @param values One-dimensional array of doubles.
+ */
+ public DenseMatrix(int numRows, int numCols, double[] values) {
+ assert (values.length == numRows * numCols);
Review comment:
Could we use `Preconditions.checkArgument(..)` here?
Note that `assert` is generally not supposed to be used in production code,
because assertions are disabled at runtime by default (they only run with `-ea`).
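For illustration, here is a sketch of the constructor with an argument check. To keep it self-contained, it defines a local `checkArgument` stand-in that mirrors the behavior of Flink's `org.apache.flink.util.Preconditions.checkArgument(boolean, Object)`: it throws `IllegalArgumentException` when the condition is false. The PR itself would just call the Flink method, as `Knn.fit` already does.

```java
/** Local stand-in mirroring Flink's Preconditions.checkArgument(boolean, Object). */
final class Checks {
    private Checks() {}

    static void checkArgument(boolean condition, Object errorMessage) {
        if (!condition) {
            throw new IllegalArgumentException(String.valueOf(errorMessage));
        }
    }
}

class DenseMatrix {
    final int numRows;
    final int numCols;
    final double[] values;

    DenseMatrix(int numRows, int numCols, double[] values) {
        // Unlike `assert`, this check runs even when the JVM is started without -ea.
        Checks.checkArgument(
                values.length == numRows * numCols,
                "values.length must equal numRows * numCols");
        this.numRows = numRows;
        this.numCols = numCols;
        this.values = values;
    }
}

class CheckArgumentDemo {
    public static void main(String[] args) {
        try {
            new DenseMatrix(2, 2, new double[3]);
        } catch (IllegalArgumentException e) {
            System.out.println("rejected: " + e.getMessage());
        }
    }
}
```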
##########
File path:
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/DenseMatrix.java
##########
@@ -0,0 +1,47 @@
+package org.apache.flink.ml.linalg;
+
+import java.io.Serializable;
+
+/**
+ * DenseMatrix stores dense matrix data and provides some methods to operate on the matrix it
+ * represents.
+ */
+public class DenseMatrix implements Serializable {
Review comment:
In the future we probably want to have `SparseMatrix`, and we would like
`SparseMatrix` and `DenseMatrix` to be accessed through the same set of public APIs.
Could we add an `interface Matrix` in this PR to clearly define how
`DenseMatrix` should be used, so that we can verify this API could support
`SparseMatrix` in the future? Maybe we could learn from Spark regarding how
this interface should be designed.
This could be done similarly to how we define `interface Vector`.
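Here is one possible shape for such an interface. It loosely follows Spark's `org.apache.spark.ml.linalg.Matrix` (which exposes `numRows`, `numCols` and element access) and Spark's CSC-based `SparseMatrix`. The method names `get` and `toDense` and both implementations below are illustrative sketches. They are only meant to check that one set of accessors can serve both layouts.

```java
import java.util.Arrays;

/** Hypothetical shared API for dense and sparse matrices. */
interface Matrix {
    int numRows();

    int numCols();

    /** Returns the entry at row i, column j. */
    double get(int i, int j);

    /** Materializes the matrix as a dense, column-major copy. */
    DenseMatrix toDense();
}

final class DenseMatrix implements Matrix {
    private final int numRows;
    private final int numCols;
    final double[] values; // column-major

    DenseMatrix(int numRows, int numCols, double[] values) {
        if (values.length != numRows * numCols) {
            throw new IllegalArgumentException("values.length must equal numRows * numCols");
        }
        this.numRows = numRows;
        this.numCols = numCols;
        this.values = values;
    }

    public int numRows() { return numRows; }

    public int numCols() { return numCols; }

    public double get(int i, int j) { return values[j * numRows + i]; }

    public DenseMatrix toDense() { return new DenseMatrix(numRows, numCols, values.clone()); }
}

/** Compressed sparse column (CSC) layout, as a possible future SparseMatrix. */
final class SparseMatrix implements Matrix {
    private final int numRows;
    private final int numCols;
    private final int[] colPtrs; // length numCols + 1; column j is [colPtrs[j], colPtrs[j + 1])
    private final int[] rowIndices; // row index of each stored value
    private final double[] values; // stored values, column by column

    SparseMatrix(int numRows, int numCols, int[] colPtrs, int[] rowIndices, double[] values) {
        this.numRows = numRows;
        this.numCols = numCols;
        this.colPtrs = colPtrs;
        this.rowIndices = rowIndices;
        this.values = values;
    }

    public int numRows() { return numRows; }

    public int numCols() { return numCols; }

    public double get(int i, int j) {
        for (int k = colPtrs[j]; k < colPtrs[j + 1]; k++) {
            if (rowIndices[k] == i) {
                return values[k];
            }
        }
        return 0.0; // not stored, hence zero
    }

    public DenseMatrix toDense() {
        double[] dense = new double[numRows * numCols];
        for (int j = 0; j < numCols; j++) {
            for (int k = colPtrs[j]; k < colPtrs[j + 1]; k++) {
                dense[j * numRows + rowIndices[k]] = values[k];
            }
        }
        return new DenseMatrix(numRows, numCols, dense);
    }
}

class MatrixDemo {
    public static void main(String[] args) {
        // The 2x2 diagonal matrix [1 0; 0 2] in CSC form.
        Matrix sparse =
                new SparseMatrix(2, 2, new int[] {0, 1, 2}, new int[] {0, 1}, new double[] {1, 2});
        System.out.println(Arrays.toString(sparse.toDense().values)); // [1.0, 0.0, 0.0, 2.0]
    }
}
```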
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,412 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.functions.AbstractRichFunction;
+import org.apache.flink.api.connector.source.Source;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.connector.file.sink.FileSink;
+import org.apache.flink.connector.file.src.FileSource;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
+import org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.utils.LogicalTypeParser;
+import org.apache.flink.table.types.utils.LogicalTypeDataTypeConverter;
+import org.apache.flink.types.Row;
+
+import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
+
+import dev.ludovic.netlib.blas.F2jBLAS;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.PriorityQueue;
+import java.util.function.Function;
+
+/** Knn classification model fitted by estimator. */
+public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> {
+ protected Map<Param<?>, Object> params = new HashMap<>();
+ private Table[] modelData;
+
+ /** Constructor. */
+ public KnnModel() {
+ ParamUtils.initializeMapWithDefaultValues(params, this);
+ }
+
+ /**
+ * Sets model data for knn prediction.
+ *
+ * @param modelData knn model.
+ * @return knn model.
+ */
+ @Override
+ public KnnModel setModelData(Table... modelData) {
+ this.modelData = modelData;
+ return this;
+ }
+
+ /**
+ * Gets model data.
+ *
+ * @return table array.
+ */
+ @Override
+ public Table[] getModelData() {
+ return modelData;
+ }
+
+ /**
+ * Predicts label with knn model.
+ *
+ * @param inputs a list of tables.
+ * @return result.
+ */
+ @Override
+ public Table[] transform(Table... inputs) {
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+ DataStream<Row> input = tEnv.toDataStream(inputs[0]);
+ DataStream<Row> model = tEnv.toDataStream(modelData[0]);
+ final String broadcastKey = "broadcastModelKey";
+ Map<String, DataStream<?>> broadcastMap = new HashMap<>(1);
+ broadcastMap.put(broadcastKey, model);
+ ResolvedSchema modelSchema = modelData[0].getResolvedSchema();
+
+ DataType idType =
+ modelSchema.getColumnDataTypes().get(modelSchema.getColumnNames().size() - 1);
+ String[] resultCols = new String[] {(String) params.get(KnnModelParams.PREDICTION_COL)};
Review comment:
Should we use `getPredictionCol()` here, so that the code is consistent
with how parameters are read elsewhere?
And since `resultCols` has exactly one element, could it be a single `predictionCol`
for simplicity?
##########
File path:
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##########
@@ -0,0 +1,278 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
+
+import static org.junit.Assert.assertEquals;
+
+/** Knn algorithm test. */
+public class KnnTest {
+ private StreamExecutionEnvironment env;
+ private StreamTableEnvironment tEnv;
+ private Table trainData;
+ private static final String LABEL_COL = "label";
+ private static final String PRED_COL = "pred";
+ private static final String VEC_COL = "vec";
+ List<Row> trainArray =
+ new ArrayList<>(
+ Arrays.asList(
+ Row.of("f", Vectors.dense(2.0, 3.0)),
+ Row.of("f", Vectors.dense(2.1, 3.1)),
+ Row.of("m", Vectors.dense(200.1, 300.1)),
+ Row.of("m", Vectors.dense(200.2, 300.2)),
+ Row.of("m", Vectors.dense(200.3, 300.3)),
+ Row.of("m", Vectors.dense(200.4, 300.4)),
+ Row.of("m", Vectors.dense(200.4, 300.4)),
+ Row.of("m", Vectors.dense(200.6, 300.6)),
+ Row.of("f", Vectors.dense(2.1, 3.1)),
+ Row.of("f", Vectors.dense(2.1, 3.1)),
+ Row.of("f", Vectors.dense(2.1, 3.1)),
+ Row.of("f", Vectors.dense(2.1, 3.1)),
+ Row.of("f", Vectors.dense(2.3, 3.2)),
+ Row.of("f", Vectors.dense(2.3, 3.2)),
+ Row.of("c", Vectors.dense(2.8, 3.2)),
+ Row.of("d", Vectors.dense(300., 3.2)),
+ Row.of("f", Vectors.dense(2.2, 3.2)),
+ Row.of("e", Vectors.dense(2.4, 3.2)),
+ Row.of("e", Vectors.dense(2.5, 3.2)),
+ Row.of("e", Vectors.dense(2.5, 3.2)),
+ Row.of("f", Vectors.dense(2.1, 3.1))));
+
+ List<Row> testArray =
+ new ArrayList<>(
+ Arrays.asList(
+ Row.of("e", Vectors.dense(4.0, 4.1)),
+ Row.of("m", Vectors.dense(300, 42))));
+ private Table testData;
+
+ @Before
+ public void before() {
+ Configuration config = new Configuration();
+ config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+ env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+ env.setParallelism(4);
+ env.enableCheckpointing(100);
+ env.setRestartStrategy(RestartStrategies.noRestart());
+ tEnv = StreamTableEnvironment.create(env);
+
+ Schema schema =
+ Schema.newBuilder()
+ .column("f0", DataTypes.STRING())
+ .column("f1", DataTypes.of(DenseVector.class))
+ .build();
+
+ DataStream<Row> dataStream = env.fromCollection(trainArray);
+ trainData = tEnv.fromDataStream(dataStream, schema).as(LABEL_COL + "," + VEC_COL);
+
+ DataStream<Row> predDataStream = env.fromCollection(testArray);
+ testData = tEnv.fromDataStream(predDataStream, schema).as(LABEL_COL + "," + VEC_COL);
+ }
+
+ // Executes the graph and returns a list which has true label and predict label.
+ private static List<Tuple2<String, String>> executeAndCollect(Table output) throws Exception {
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
+
+ DataStream<Tuple2<String, String>> stream =
+ tEnv.toDataStream(output)
+ .map(
+ new MapFunction<Row, Tuple2<String, String>>() {
+ @Override
+ public Tuple2<String, String> map(Row row) {
+ return Tuple2.of(
+ (String) row.getField(LABEL_COL),
+ (String) row.getField(PRED_COL));
+ }
+ });
+ return IteratorUtils.toList(stream.executeAndCollect());
+ }
+
+ private static void verifyClusteringResult(List<Tuple2<String, String>> result) {
+ for (Tuple2<String, String> t2 : result) {
+ Assert.assertEquals(t2.f0, t2.f1);
+ }
+ }
+
+ /** Tests Param. */
+ @Test
+ public void testParam() {
+ Knn knnOrigin = new Knn();
Review comment:
Could we make the variable names and test style consistent with other
tests (e.g. `KMeansTest::testParam`)?
These two Knn instances do not seem to have the "original" relationship that the
name `knnOrigin` suggests, and it seems simpler to just set the parameters of the
first instance instead of creating a second one.
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
##########
@@ -0,0 +1,227 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.MapPartitionFunctionWrapper;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.flink.shaded.curator4.com.google.common.collect.Iterables;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * KNN is to classify unlabeled observations by assigning them to the class of the most similar
+ * labeled examples.
+ */
+public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> {
+
+ protected Map<Param<?>, Object> params = new HashMap<>();
+
+ /** Constructor. */
+ public Knn() {
+ ParamUtils.initializeMapWithDefaultValues(params, this);
+ }
+
+ /**
+ * Fits data and produces knn model.
+ *
+ * @param inputs A list of tables
+ * @return Knn model.
+ */
+ @Override
+ public KnnModel fit(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ ResolvedSchema schema = inputs[0].getResolvedSchema();
+ String[] colNames = schema.getColumnNames().toArray(new String[0]);
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+ DataStream<Row> input = tEnv.toDataStream(inputs[0]);
+ String labelCol = getLabelCol();
+ String vecCol = getFeaturesCol();
+
+ DataStream<Row> trainData =
+ input.map(
+ (MapFunction<Row, Row>)
+ value -> {
+ Object label = String.valueOf(value.getField(labelCol));
Review comment:
Could we use a concrete class (e.g. `Double`/`Integer`/`String`) rather than
`Object` where possible?
Same for other usages of `Object` in `KnnModel`.
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,412 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.functions.AbstractRichFunction;
+import org.apache.flink.api.connector.source.Source;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.connector.file.sink.FileSink;
+import org.apache.flink.connector.file.src.FileSource;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
+import org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.utils.LogicalTypeParser;
+import org.apache.flink.table.types.utils.LogicalTypeDataTypeConverter;
+import org.apache.flink.types.Row;
+
+import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
+
+import dev.ludovic.netlib.blas.F2jBLAS;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.PriorityQueue;
+import java.util.function.Function;
+
+/** Knn classification model fitted by estimator. */
+public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> {
+ protected Map<Param<?>, Object> params = new HashMap<>();
+ private Table[] modelData;
+
+ /** Constructor. */
+ public KnnModel() {
+ ParamUtils.initializeMapWithDefaultValues(params, this);
+ }
+
+ /**
+ * Sets model data for knn prediction.
+ *
+ * @param modelData knn model.
+ * @return knn model.
+ */
+ @Override
+ public KnnModel setModelData(Table... modelData) {
+ this.modelData = modelData;
+ return this;
+ }
+
+ /**
+ * Gets model data.
+ *
+ * @return table array.
+ */
+ @Override
+ public Table[] getModelData() {
+ return modelData;
+ }
+
+ /**
+ * Predicts label with knn model.
+ *
+ * @param inputs a list of tables.
+ * @return result.
+ */
+ @Override
+ public Table[] transform(Table... inputs) {
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+ DataStream<Row> input = tEnv.toDataStream(inputs[0]);
+ DataStream<Row> model = tEnv.toDataStream(modelData[0]);
+ final String broadcastKey = "broadcastModelKey";
+ Map<String, DataStream<?>> broadcastMap = new HashMap<>(1);
+ broadcastMap.put(broadcastKey, model);
+ ResolvedSchema modelSchema = modelData[0].getResolvedSchema();
+
+ DataType idType =
+ modelSchema.getColumnDataTypes().get(modelSchema.getColumnNames().size() - 1);
+ String[] resultCols = new String[] {(String) params.get(KnnModelParams.PREDICTION_COL)};
+ DataType[] resultTypes = new DataType[] {idType};
+
+ ResolvedSchema outputSchema =
+ TableUtils.getOutputSchema(inputs[0].getResolvedSchema(), resultCols, resultTypes);
+
+ Function<List<DataStream<?>>, DataStream<Row>> function =
+ dataStreams -> {
+ DataStream stream = dataStreams.get(0);
+ return stream.transform(
+ "mapFunc",
+ TableUtils.getRowTypeInfo(outputSchema),
+ new PredictOperator(
+ inputs[0].getResolvedSchema(),
+ broadcastKey,
+ getK(),
+ getFeaturesCol()));
+ };
+
+ DataStream<Row> output =
+ BroadcastUtils.withBroadcastStream(
+ Collections.singletonList(input), broadcastMap, function);
+ return new Table[] {tEnv.fromDataStream(output)};
+ }
+
+ /** This operator load the model data and do the prediction. */
+ private static class PredictOperator
+ extends AbstractUdfStreamOperator<Row, AbstractRichFunction>
+ implements OneInputStreamOperator<Row, Row> {
+
+ private boolean firstEle = true;
+ private final String[] reservedCols;
+ private final String featureCol;
+ private transient KnnModelData modelData;
+ private final Integer topN;
+ private final String broadcastKey;
+
+ public PredictOperator(
+ ResolvedSchema dataSchema, String broadcastKey, int k, String featureCol) {
+ super(new AbstractRichFunction() {});
+ reservedCols = dataSchema.getColumnNames().toArray(new String[0]);
+ this.topN = k;
+ this.broadcastKey = broadcastKey;
+ this.featureCol = featureCol;
+ }
+
+ @Override
+ public void processElement(StreamRecord<Row> streamRecord) throws Exception {
+ Row value = streamRecord.getValue();
+ output.collect(new StreamRecord<>(map(value)));
+ }
+
+ public Row map(Row row) throws Exception {
+ if (firstEle) {
+ loadModel(userFunction.getRuntimeContext().getBroadcastVariable(broadcastKey));
+ firstEle = false;
+ }
+ DenseVector vector = (DenseVector) row.getField(featureCol);
+ Tuple2<List<Object>, List<Double>> t2 = findNeighbor(vector, topN, modelData);
+ Row ret = new Row(reservedCols.length + 1);
+ for (int i = 0; i < reservedCols.length; ++i) {
+ ret.setField(i, row.getField(reservedCols[i]));
+ }
+
+ Tuple2<Object, String> tuple2 = getResultFormat(t2);
+ ret.setField(reservedCols.length, tuple2.f0);
+ return ret;
+ }
+
+ /**
+ * Finds the nearest topN neighbors from whole nodes.
+ *
+ * @param input input node.
+ * @param topN top N.
+ * @return neighbor.
+ */
+ private Tuple2<List<Object>, List<Double>> findNeighbor(
+ Object input, Integer topN, KnnModelData modelData) {
+ PriorityQueue<Tuple2<Double, Object>> priorityQueue =
+ new PriorityQueue<>(modelData.getQueueComparator());
+ search(input, topN, priorityQueue, modelData);
+ List<Object> items = new ArrayList<>();
+ List<Double> metrics = new ArrayList<>();
+ while (!priorityQueue.isEmpty()) {
+ Tuple2<Double, Object> result = priorityQueue.poll();
+ items.add(result.f1);
+ metrics.add(result.f0);
+ }
+ Collections.reverse(items);
+ Collections.reverse(metrics);
+ priorityQueue.clear();
+ return Tuple2.of(items, metrics);
+ }
+
+ /**
+ * @param input input node.
+ * @param topN top N.
+ * @param priorityQueue priority queue.
+ */
+ private void search(
+ Object input,
+ Integer topN,
+ PriorityQueue<Tuple2<Double, Object>> priorityQueue,
+ KnnModelData modelData) {
+ Tuple2<DenseVector, Double> sample = computeNorm((DenseVector) input);
+ Tuple2<Double, Object> head = null;
+ for (int i = 0; i < modelData.getLength(); i++) {
+ ArrayList<Tuple2<Double, Object>> values = computeDistance(sample, i);
+ for (Tuple2<Double, Object> currentValue : values) {
+ head = updateQueue(priorityQueue, topN, currentValue, head);
+ }
+ }
+ }
+
+ /**
+ * Updates queue.
+ *
+ * @param map queue.
+ * @param topN top N.
+ * @param newValue new value.
+ * @param head head value.
+ * @param <T> id type.
+ * @return head value.
+ */
+ private <T> Tuple2<Double, T> updateQueue(
+ PriorityQueue<Tuple2<Double, T>> map,
+ int topN,
+ Tuple2<Double, T> newValue,
+ Tuple2<Double, T> head) {
+ if (map.size() < topN) {
+ map.add(Tuple2.of(newValue.f0, newValue.f1));
+ head = map.peek();
+ } else {
+ if (map.comparator().compare(head, newValue) < 0) {
+ Tuple2<Double, T> peek = map.poll();
+ assert peek != null;
+ peek.f0 = newValue.f0;
+ peek.f1 = newValue.f1;
+ map.add(peek);
+ head = map.peek();
+ }
+ }
+ return head;
+ }
+
+ /**
+ * Computes distance between sample and dictionary vectors.
+ *
+ * @param input sample with l2 norm.
+ * @param index dictionary vectors index.
+ * @return distances.
+ */
+ private ArrayList<Tuple2<Double, Object>> computeDistance(
+ Tuple2<DenseVector, Double> input, Integer index) {
+ Tuple3<DenseMatrix, DenseVector, String[]> data = modelData.getDictData().get(index);
+
+ DenseMatrix res = calc(input, data);
+ ArrayList<Tuple2<Double, Object>> list = new ArrayList<>(0);
+ String[] curLabels = data.f2;
+ for (int i = 0; i < Objects.requireNonNull(curLabels).length; i++) {
+ Tuple2<Double, Object> tuple = Tuple2.of(res.values[i], curLabels[i]);
+ list.add(tuple);
+ }
+ return list;
+ }
+
+ /** The blas used to accelerating speed. */
+ private static final dev.ludovic.netlib.blas.F2jBLAS NATIVE_BLAS =
+ (F2jBLAS) F2jBLAS.getInstance();
+
+ /**
+ * Compute distance between sample and dictionary vectors.
+ *
+ * @param left Sample and norm.
+ * @param right Dictionary vectors with row format.
+ * @return a new DenseMatrix which store the result distance.
+ */
+ public DenseMatrix calc(
+ Tuple2<DenseVector, Double> left,
+ Tuple3<DenseMatrix, DenseVector, String[]> right) {
+ DenseMatrix vectors = right.f0;
+ DenseMatrix res = new DenseMatrix(Objects.requireNonNull(vectors).numCols, 1);
+ DenseVector norm = right.f1;
+ double[] normL2Square = Objects.requireNonNull(norm).values;
+
+ final int m = vectors.numRows;
+ final int n = vectors.numCols;
+ NATIVE_BLAS.dgemv(
+ "T", m, n, -2.0, vectors.values, m, left.f0.toArray(), 1, 0.0, res.values, 1);
+
+ for (int i = 0; i < res.values.length; i++) {
+ res.values[i] = Math.sqrt(Math.abs(res.values[i] + left.f1 + normL2Square[i]));
+ }
+ return res;
+ }
+
+ /**
+ * Computes norm2 of vector.
+ *
+ * @return Sample with norm2.
+ */
+ public static Tuple2<DenseVector, Double> computeNorm(DenseVector vector) {
Review comment:
Could this method be private? Same for other methods such as `calc(...)`
and `map(...)` defined in this file.
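For reference, `calc(...)` relies on the identity ||x - y_i||^2 = ||x||^2 + ||y_i||^2 - 2 x·y_i: the `dgemv("T", ...)` call produces the -2 x·y_i term for every dictionary vector at once, and the loop then adds both squared norms before taking the square root. Below is a minimal plain-Java sketch of the same computation, assuming `DenseMatrix.values` is column-major (which is what passing `lda = m` to `dgemv` expects), so each column holds one dictionary vector, and assuming `computeNorm(...)` stores the *squared* L2 norm in `f1` (the identity only holds if it does). `KnnDistanceSketch`, `squaredNorm` and `distances` are hypothetical names.

```java
import java.util.Arrays;

/** Hypothetical sketch: distances via ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y. */
public class KnnDistanceSketch {

    /** Squared L2 norm; assumed to match what computeNorm(...) stores in f1. */
    static double squaredNorm(double[] v) {
        double s = 0.0;
        for (double d : v) {
            s += d * d;
        }
        return s;
    }

    /**
     * Euclidean distances from x to each column of an m x n column-major matrix,
     * where column j holds dictionary vector j (m = dimension, n = vector count).
     */
    static double[] distances(double[] x, double[] colMajor, int m, int n) {
        double xNorm2 = squaredNorm(x);
        double[] res = new double[n];
        for (int j = 0; j < n; j++) {
            double dot = 0.0;
            double yNorm2 = 0.0;
            for (int i = 0; i < m; i++) {
                double y = colMajor[j * m + i];
                dot += x[i] * y;
                yNorm2 += y * y;
            }
            // Math.abs guards against tiny negative values from rounding, as calc(...) does.
            res[j] = Math.sqrt(Math.abs(xNorm2 + yNorm2 - 2.0 * dot));
        }
        return res;
    }

    public static void main(String[] args) {
        double[] x = {2.0, 3.0};
        // Dictionary vectors (2.0, 3.0) and (5.0, 7.0), stored column by column.
        double[] dict = {2.0, 3.0, 5.0, 7.0};
        System.out.println(Arrays.toString(distances(x, dict, 2, 2))); // [0.0, 5.0]
    }
}
```

Precomputing ||y_i||^2 once per dictionary block (as the model data's `DenseVector` does) is what makes this expansion cheaper than subtracting vectors for every query.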
##########
File path:
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##########
@@ -0,0 +1,278 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
+
+import static org.junit.Assert.assertEquals;
+
+/** Knn algorithm test. */
+public class KnnTest {
+ private StreamExecutionEnvironment env;
+ private StreamTableEnvironment tEnv;
+ private Table trainData;
+ private static final String LABEL_COL = "label";
+ private static final String PRED_COL = "pred";
+ private static final String VEC_COL = "vec";
+ List<Row> trainArray =
+ new ArrayList<>(
+ Arrays.asList(
+ Row.of("f", Vectors.dense(2.0, 3.0)),
Review comment:
Hmm... why is the label type `String`? It looks like `int` or `double` would
be more efficient than `String` in terms of memory and network bandwidth
usage, right?
Currently other algorithms (e.g. KMeans and LogisticRegression) use
numerical values to represent the class/category. Could KNN do the same for
consistency?
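With numeric labels, the final prediction over the k nearest neighbors reduces to a vote count keyed by `double`. A minimal sketch, assuming plain majority voting with a smallest-label tie-break; neither rule is taken from the PR, whose label-selection logic (`getResultFormat(...)` in `KnnModel`) is not shown in this diff. `KnnVoteSketch` and `vote` are hypothetical names.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical sketch: majority vote over the k nearest neighbors' numeric labels. */
public class KnnVoteSketch {

    /**
     * Returns the most frequent label. Ties go to the smaller label value (an assumed
     * rule, chosen only for determinism). Returns NaN for an empty list.
     */
    static double vote(List<Double> neighborLabels) {
        Map<Double, Integer> counts = new HashMap<>();
        for (double label : neighborLabels) {
            counts.merge(label, 1, Integer::sum);
        }
        double best = Double.NaN;
        int bestCount = 0;
        for (Map.Entry<Double, Integer> e : counts.entrySet()) {
            double label = e.getKey();
            int count = e.getValue();
            if (count > bestCount || (count == bestCount && label < best)) {
                best = label;
                bestCount = count;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        System.out.println(vote(Arrays.asList(1.0, 0.0, 1.0))); // 1.0
    }
}
```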
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
##########
@@ -0,0 +1,227 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.MapPartitionFunctionWrapper;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.flink.shaded.curator4.com.google.common.collect.Iterables;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * KNN is to classify unlabeled observations by assigning them to the class of the most similar
+ * labeled examples.
+ */
+public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> {
+
+ protected Map<Param<?>, Object> params = new HashMap<>();
+
+ /** Constructor. */
+ public Knn() {
+ ParamUtils.initializeMapWithDefaultValues(params, this);
+ }
+
+ /**
+ * Fits data and produces knn model.
+ *
+ * @param inputs A list of tables
+ * @return Knn model.
+ */
+ @Override
+ public KnnModel fit(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ ResolvedSchema schema = inputs[0].getResolvedSchema();
+ String[] colNames = schema.getColumnNames().toArray(new String[0]);
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+ DataStream<Row> input = tEnv.toDataStream(inputs[0]);
+ String labelCol = getLabelCol();
+ String vecCol = getFeaturesCol();
+
+ DataStream<Row> trainData =
+ input.map(
+ (MapFunction<Row, Row>)
+ value -> {
+ Object label = String.valueOf(value.getField(labelCol));
+ DenseVector vec = (DenseVector) value.getField(vecCol);
+ return Row.of(label, vec);
+ });
+ DataType idType = null;
+ for (int i = 0; i < colNames.length; i++) {
+ if (labelCol.equalsIgnoreCase(colNames[i])) {
Review comment:
Other than "backward-compatibility", is there any reason we should
support case-insensitive matching?
Could we make the core algorithm code cleaner and simpler by requiring
case-sensitive column name matching?
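If case-sensitive matching is adopted, the `equalsIgnoreCase` loop over `colNames` in `fit(...)` could collapse into a single `indexOf` lookup (which uses `String.equals`) that fails fast on a missing column. A minimal sketch; `ColumnLookupSketch` and `findColumnIndex` are hypothetical names, and the returned index could then be used with `schema.getColumnDataTypes().get(idx)`, as `KnnModel` already does for its model schema.

```java
import java.util.Arrays;
import java.util.List;

/** Hypothetical sketch: strict, case-sensitive column lookup. */
public class ColumnLookupSketch {

    /** Returns the index of {@code col}, failing fast if no column matches exactly. */
    static int findColumnIndex(List<String> colNames, String col) {
        int idx = colNames.indexOf(col); // String.equals, i.e. case-sensitive
        if (idx < 0) {
            throw new IllegalArgumentException("Column '" + col + "' not found in " + colNames);
        }
        return idx;
    }

    public static void main(String[] args) {
        List<String> cols = Arrays.asList("label", "vec");
        System.out.println(findColumnIndex(cols, "vec")); // 1
        try {
            findColumnIndex(cols, "Label"); // different case: rejected
        } catch (IllegalArgumentException e) {
            System.out.println(e.getMessage());
        }
    }
}
```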
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,412 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.functions.AbstractRichFunction;
+import org.apache.flink.api.connector.source.Source;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.connector.file.sink.FileSink;
+import org.apache.flink.connector.file.src.FileSource;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
+import org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.utils.LogicalTypeParser;
+import org.apache.flink.table.types.utils.LogicalTypeDataTypeConverter;
+import org.apache.flink.types.Row;
+
+import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.JsonProcessingException;
+
+import dev.ludovic.netlib.blas.F2jBLAS;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.PriorityQueue;
+import java.util.function.Function;
+
+/** Knn classification model fitted by estimator. */
+public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> {
+ protected Map<Param<?>, Object> params = new HashMap<>();
+ private Table[] modelData;
+
+ /** Constructor. */
+ public KnnModel() {
+ ParamUtils.initializeMapWithDefaultValues(params, this);
+ }
+
+ /**
+ * Sets model data for knn prediction.
+ *
+ * @param modelData knn model.
+ * @return knn model.
+ */
+ @Override
+ public KnnModel setModelData(Table... modelData) {
+ this.modelData = modelData;
+ return this;
+ }
+
+ /**
+ * Gets model data.
+ *
+ * @return table array.
+ */
+ @Override
+ public Table[] getModelData() {
+ return modelData;
+ }
+
+ /**
+ * Predicts label with knn model.
+ *
+ * @param inputs a list of tables.
+ * @return result.
+ */
+ @Override
+ public Table[] transform(Table... inputs) {
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+ DataStream<Row> input = tEnv.toDataStream(inputs[0]);
+ DataStream<Row> model = tEnv.toDataStream(modelData[0]);
+ final String broadcastKey = "broadcastModelKey";
+ Map<String, DataStream<?>> broadcastMap = new HashMap<>(1);
+ broadcastMap.put(broadcastKey, model);
+ ResolvedSchema modelSchema = modelData[0].getResolvedSchema();
+
+ DataType idType =
+ modelSchema.getColumnDataTypes().get(modelSchema.getColumnNames().size() - 1);
+ String[] resultCols = new String[] {(String) params.get(KnnModelParams.PREDICTION_COL)};
+ DataType[] resultTypes = new DataType[] {idType};
+
+ ResolvedSchema outputSchema =
+ TableUtils.getOutputSchema(inputs[0].getResolvedSchema(), resultCols, resultTypes);
+
+ Function<List<DataStream<?>>, DataStream<Row>> function =
+ dataStreams -> {
+ DataStream stream = dataStreams.get(0);
+ return stream.transform(
+ "mapFunc",
+ TableUtils.getRowTypeInfo(outputSchema),
+ new PredictOperator(
+ inputs[0].getResolvedSchema(),
+ broadcastKey,
+ getK(),
+ getFeaturesCol()));
+ };
+
+ DataStream<Row> output =
+ BroadcastUtils.withBroadcastStream(
+ Collections.singletonList(input), broadcastMap, function);
+ return new Table[] {tEnv.fromDataStream(output)};
+ }
+
+ /** This operator load the model data and do the prediction. */
+ private static class PredictOperator
+ extends AbstractUdfStreamOperator<Row, AbstractRichFunction>
+ implements OneInputStreamOperator<Row, Row> {
+
+ private boolean firstEle = true;
+ private final String[] reservedCols;
+ private final String featureCol;
+ private transient KnnModelData modelData;
+ private final Integer topN;
+ private final String broadcastKey;
+
+ public PredictOperator(
+ ResolvedSchema dataSchema, String broadcastKey, int k, String featureCol) {
+ super(new AbstractRichFunction() {});
+ reservedCols = dataSchema.getColumnNames().toArray(new String[0]);
+ this.topN = k;
+ this.broadcastKey = broadcastKey;
+ this.featureCol = featureCol;
+ }
+
+ @Override
+ public void processElement(StreamRecord<Row> streamRecord) throws Exception {
+ Row value = streamRecord.getValue();
+ output.collect(new StreamRecord<>(map(value)));
+ }
+
+ public Row map(Row row) throws Exception {
+ if (firstEle) {
+ loadModel(userFunction.getRuntimeContext().getBroadcastVariable(broadcastKey));
+ firstEle = false;
+ }
+ DenseVector vector = (DenseVector) row.getField(featureCol);
+ Tuple2<List<Object>, List<Double>> t2 = findNeighbor(vector, topN, modelData);
+ Row ret = new Row(reservedCols.length + 1);
+ for (int i = 0; i < reservedCols.length; ++i) {
+ ret.setField(i, row.getField(reservedCols[i]));
+ }
+
+ Tuple2<Object, String> tuple2 = getResultFormat(t2);
+ ret.setField(reservedCols.length, tuple2.f0);
+ return ret;
+ }
+
+ /**
+ * Finds the nearest topN neighbors from whole nodes.
+ *
+ * @param input input node.
+ * @param topN top N.
+ * @return neighbor.
+ */
+ private Tuple2<List<Object>, List<Double>> findNeighbor(
+ Object input, Integer topN, KnnModelData modelData) {
+ PriorityQueue<Tuple2<Double, Object>> priorityQueue =
+ new PriorityQueue<>(modelData.getQueueComparator());
+ search(input, topN, priorityQueue, modelData);
+ List<Object> items = new ArrayList<>();
+ List<Double> metrics = new ArrayList<>();
+ while (!priorityQueue.isEmpty()) {
+ Tuple2<Double, Object> result = priorityQueue.poll();
+ items.add(result.f1);
+ metrics.add(result.f0);
+ }
+ Collections.reverse(items);
+ Collections.reverse(metrics);
+ priorityQueue.clear();
+ return Tuple2.of(items, metrics);
+ }
+
+ /**
+ * @param input input node.
+ * @param topN top N.
+ * @param priorityQueue priority queue.
+ */
+ private void search(
+ Object input,
+ Integer topN,
+ PriorityQueue<Tuple2<Double, Object>> priorityQueue,
+ KnnModelData modelData) {
+ Tuple2<DenseVector, Double> sample = computeNorm((DenseVector) input);
+ Tuple2<Double, Object> head = null;
+ for (int i = 0; i < modelData.getLength(); i++) {
+ ArrayList<Tuple2<Double, Object>> values = computeDistance(sample, i);
+ for (Tuple2<Double, Object> currentValue : values) {
+ head = updateQueue(priorityQueue, topN, currentValue, head);
+ }
+ }
+ }
+
+ /**
+ * Updates queue.
+ *
+ * @param map queue.
+ * @param topN top N.
+ * @param newValue new value.
+ * @param head head value.
+ * @param <T> id type.
+ * @return head value.
+ */
+ private <T> Tuple2<Double, T> updateQueue(
+ PriorityQueue<Tuple2<Double, T>> map,
+ int topN,
+ Tuple2<Double, T> newValue,
+ Tuple2<Double, T> head) {
+ if (map.size() < topN) {
+ map.add(Tuple2.of(newValue.f0, newValue.f1));
+ head = map.peek();
+ } else {
+ if (map.comparator().compare(head, newValue) < 0) {
+ Tuple2<Double, T> peek = map.poll();
+ assert peek != null;
+ peek.f0 = newValue.f0;
+ peek.f1 = newValue.f1;
+ map.add(peek);
+ head = map.peek();
+ }
+ }
+ return head;
+ }
+
+ /**
+ * Computes distance between sample and dictionary vectors.
+ *
+ * @param input sample with l2 norm.
+ * @param index dictionary vectors index.
+ * @return distances.
+ */
+ private ArrayList<Tuple2<Double, Object>> computeDistance(
+ Tuple2<DenseVector, Double> input, Integer index) {
+ Tuple3<DenseMatrix, DenseVector, String[]> data = modelData.getDictData().get(index);
+
+ DenseMatrix res = calc(input, data);
+ ArrayList<Tuple2<Double, Object>> list = new ArrayList<>(0);
+ String[] curLabels = data.f2;
+ for (int i = 0; i < Objects.requireNonNull(curLabels).length; i++) {
+ Tuple2<Double, Object> tuple = Tuple2.of(res.values[i], curLabels[i]);
+ list.add(tuple);
+ }
+ return list;
+ }
+
+ /** The blas used to accelerating speed. */
+ private static final dev.ludovic.netlib.blas.F2jBLAS NATIVE_BLAS =
Review comment:
We will need to put BLAS-related operations in a shared infra class file
such as `BLAS.java`. Both the NaiveBayes and LogisticRegression PRs do this.
Maybe you could rebase your PR on e.g. the LogisticRegression PR and add the
`dgemv` method as appropriate?
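To make the intended contract of such a shared `dgemv` method concrete, here is a pure-Java reference of the transposed case that `calc(...)` uses (`"T"`, `lda = m`): y := alpha * A^T * x + beta * y for an m x n column-major A. This is a sketch under assumptions: `BlasSketch` and `gemvTrans` are hypothetical names, it is not the netlib-java API, and the actual `BLAS.java` in the NaiveBayes/LogisticRegression PRs may be shaped differently. A production wrapper would delegate to `NATIVE_BLAS.dgemv(...)` with the argument order shown in the diff rather than loop in Java.

```java
import java.util.Arrays;

/** Hypothetical reference for a shared gemv helper (transposed, column-major case). */
public final class BlasSketch {

    private BlasSketch() {}

    /**
     * y := alpha * A^T * x + beta * y, where A is m x n, stored column-major in {@code a}
     * with leading dimension m; x has length m and y has length n. As in reference BLAS,
     * y is not read when beta == 0.
     */
    static void gemvTrans(
            double alpha, double[] a, int m, int n, double[] x, double beta, double[] y) {
        if (a.length < m * n || x.length < m || y.length < n) {
            throw new IllegalArgumentException("Dimension mismatch");
        }
        for (int j = 0; j < n; j++) {
            double dot = 0.0;
            for (int i = 0; i < m; i++) {
                dot += a[j * m + i] * x[i]; // column j of A dotted with x
            }
            y[j] = beta == 0.0 ? alpha * dot : alpha * dot + beta * y[j];
        }
    }

    public static void main(String[] args) {
        double[] a = {2.0, 3.0, 5.0, 7.0}; // columns (2, 3) and (5, 7)
        double[] x = {2.0, 3.0};
        double[] y = new double[2];
        gemvTrans(-2.0, a, 2, 2, x, 0.0, y); // same alpha/beta as calc(...)
        System.out.println(Arrays.toString(y)); // [-26.0, -62.0]
    }
}
```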
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.