zhipeng93 commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r768280582
##########
File path:
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/DenseMatrix.java
##########
@@ -0,0 +1,68 @@
+package org.apache.flink.ml.linalg;
+
+import org.apache.flink.api.common.typeinfo.TypeInfo;
+import org.apache.flink.ml.linalg.typeinfo.DenseMatrixTypeInfoFactory;
+import org.apache.flink.util.Preconditions;
+
+/**
+ * Column-major dense matrix. The entry values are stored in a single array of
doubles with columns
+ * listed in sequence.
+ */
+@TypeInfo(DenseMatrixTypeInfoFactory.class)
+public class DenseMatrix implements Matrix {
+
+ /** Row dimension. */
+ private final int numRows;
+
+ /** Column dimension. */
+ private final int numCols;
+
+ /**
+ * Array for internal storage of elements.
+ *
+ * <p>The matrix data is stored in column major format internally.
+ */
+ public final double[] values;
+
+ /**
+ * Constructs an m-by-n matrix of zeros.
+ *
+ * @param numRows Number of rows.
+ * @param numCols Number of columns.
+ */
+ public DenseMatrix(int numRows, int numCols) {
+ this(numRows, numCols, new double[numRows * numCols]);
+ }
+
+ /**
+ * Constructs a matrix from a 1-D array. The data in the array should
organize in column major.
Review comment:
nit: "should organize" ---> "should be organized"
##########
File path: flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java
##########
@@ -52,4 +52,37 @@ public static double norm2(DenseVector x) {
public static void scal(double a, DenseVector x) {
JAVA_BLAS.dscal(x.size(), a, x.values, 1);
}
-}
+
+ /**
+ * y = alpha * matrix * x + beta * y or y = alpha * (matrix^T) * x + beta
* y.
+ *
+ * @param matrix m x n matrix.
Review comment:
Can we update the comment as:
```
@param matrix Dense matrix with size m x n.
@param transMatrix Whether transposes matrix before multiply.
@param x Dense vector with size n.
@param y Dense vector with size m.
```
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
##########
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the KNN algorithm.
+ *
+ * <p>See: https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm.
+ */
+public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> {
+
+ protected Map<Param<?>, Object> params = new HashMap<>();
+
+ public Knn() {
+ ParamUtils.initializeMapWithDefaultValues(params, this);
+ }
+
+ @Override
+ public KnnModel fit(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl)
inputs[0]).getTableEnvironment();
+ /* Tuple2 : <sampleVector, label> */
+ DataStream<Tuple2<DenseVector, Double>> inputData =
+ tEnv.toDataStream(inputs[0])
+ .map(
+ new MapFunction<Row, Tuple2<DenseVector,
Double>>() {
+ @Override
+ public Tuple2<DenseVector, Double> map(Row
value) {
+ Double label = (Double)
value.getField(getLabelCol());
+ DenseVector feature =
+ (DenseVector)
value.getField(getFeaturesCol());
+ return Tuple2.of(feature, label);
+ }
+ });
+ DataStream<KnnModelData> distributedModelData =
prepareModelData(inputData);
+ DataStream<KnnModelData> modelData =
mergeModelData(distributedModelData);
+ KnnModel model = new
KnnModel().setModelData(tEnv.fromDataStream(modelData));
+ ReadWriteUtils.updateExistingParams(model, getParamMap());
+ return model;
+ }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return this.params;
+ }
+
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ }
+
+ public static Knn load(StreamExecutionEnvironment env, String path) throws
IOException {
+ return ReadWriteUtils.loadStageParam(path);
+ }
+
+ /**
+ * Prepares distributed knn model data. Constructs the sample matrix and
computes norm of
+ * features.
+ *
+ * @param inputData Input feature data with label.
+ * @return Distributed knn model.
+ */
+ private static DataStream<KnnModelData> prepareModelData(
Review comment:
It seems to me that the following implementation is clearer:
We rename `prepareModelData()` to `computeL2Norm()` and make it a simple
`MapFunction`.
In this function, we compute the L2 norm of each input data point and output
`Tuple3<DenseVector, Double, Double>` (i.e., feature, label, l2norm) for each
data point.
Then we rename `mergeModelData` to `genModelData`. We can pack the
`features, labels, norms` into `KnnModelData` in the `mapPartitionFunction`
function.
In this case, the semantics of `KnnModelData` are unambiguous: only all of the
input data points together can form the final `model data`. Also, we can avoid
allocating memory for `KnnModelData` twice.
What do you think?
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
##########
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the KNN algorithm.
+ *
+ * <p>See: https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm.
+ */
+public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> {
+
+ protected Map<Param<?>, Object> params = new HashMap<>();
Review comment:
protected --> private
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+
+/** A Model which classifies data using the model data computed by {@link
Knn}. */
+public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> {
+ protected Map<Param<?>, Object> params = new HashMap<>();
+ private Table modelDataTable;
+
+ public KnnModel() {
+ ParamUtils.initializeMapWithDefaultValues(params, this);
+ }
+
+ @Override
+ public KnnModel setModelData(Table... modelData) {
+ this.modelDataTable = modelData[0];
+ return this;
+ }
+
+ @Override
+ public Table[] getModelData() {
+ return new Table[] {modelDataTable};
+ }
+
+ @Override
+ @SuppressWarnings("unchecked")
+ public Table[] transform(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl)
inputs[0]).getTableEnvironment();
+ DataStream<Row> data = tEnv.toDataStream(inputs[0]);
+ DataStream<KnnModelData> knnModel =
KnnModelData.getModelDataStream(modelDataTable);
+ final String broadcastModelKey = "broadcastModelKey";
+ RowTypeInfo inputTypeInfo =
TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+ RowTypeInfo outputTypeInfo =
+ new RowTypeInfo(
+ ArrayUtils.addAll(
+ inputTypeInfo.getFieldTypes(),
BasicTypeInfo.DOUBLE_TYPE_INFO),
+ ArrayUtils.addAll(inputTypeInfo.getFieldNames(),
getPredictionCol()));
+ DataStream<Row> output =
+ BroadcastUtils.withBroadcastStream(
+ Collections.singletonList(data),
+ Collections.singletonMap(broadcastModelKey, knnModel),
+ inputList -> {
+ DataStream input = inputList.get(0);
+ return input.map(
+ new PredictLabelFunction(
+ broadcastModelKey, getK(),
getFeaturesCol()),
+ outputTypeInfo);
+ });
+ return new Table[] {tEnv.fromDataStream(output)};
+ }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return this.params;
+ }
+
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ ReadWriteUtils.saveModelData(
+ KnnModelData.getModelDataStream(modelDataTable),
+ path,
+ new KnnModelData.ModelDataEncoder());
+ }
+
+ /**
+ * Loads model data from path.
+ *
+ * @param env Stream execution environment.
+ * @param path Model path.
+ * @return Knn model.
+ */
+ public static KnnModel load(StreamExecutionEnvironment env, String path)
throws IOException {
+ StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+ KnnModel model = ReadWriteUtils.loadStageParam(path);
+ DataStream<KnnModelData> modelData =
+ ReadWriteUtils.loadModelData(env, path, new
KnnModelData.ModelDataDecoder());
+ return model.setModelData(tEnv.fromDataStream(modelData));
+ }
+
+ /** This operator loads model data and predicts result. */
+ private static class PredictLabelFunction extends RichMapFunction<Row,
Row> {
+ private final String featureCol;
+ private KnnModelData knnModelData;
+ private final int k;
+ private final String broadcastKey;
+
+ public PredictLabelFunction(String broadcastKey, int k, String
featureCol) {
+ this.k = k;
+ this.broadcastKey = broadcastKey;
+ this.featureCol = featureCol;
+ }
+
+ @Override
+ public Row map(Row row) {
+ if (knnModelData == null) {
+ knnModelData =
+ (KnnModelData)
+
getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
+ }
+ DenseVector feature = (DenseVector) row.getField(featureCol);
+ List<Double> labels = findNearestKLabels(feature);
Review comment:
Can we move the logic of merging labels into `findNearestKLabels()` and
let this method return a single `double` for prediction?
We can probably rename `findNearestKLabels()` to `predict()` or something
like `findMaxProbLabel()`.
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
##########
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the KNN algorithm.
+ *
+ * <p>See: https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm.
+ */
+public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> {
+
+ protected Map<Param<?>, Object> params = new HashMap<>();
+
+ public Knn() {
+ ParamUtils.initializeMapWithDefaultValues(params, this);
+ }
+
+ @Override
+ public KnnModel fit(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl)
inputs[0]).getTableEnvironment();
+ /* Tuple2 : <sampleVector, label> */
+ DataStream<Tuple2<DenseVector, Double>> inputData =
+ tEnv.toDataStream(inputs[0])
+ .map(
+ new MapFunction<Row, Tuple2<DenseVector,
Double>>() {
+ @Override
+ public Tuple2<DenseVector, Double> map(Row
value) {
+ Double label = (Double)
value.getField(getLabelCol());
+ DenseVector feature =
+ (DenseVector)
value.getField(getFeaturesCol());
+ return Tuple2.of(feature, label);
+ }
+ });
+ DataStream<KnnModelData> distributedModelData =
prepareModelData(inputData);
+ DataStream<KnnModelData> modelData =
mergeModelData(distributedModelData);
+ KnnModel model = new
KnnModel().setModelData(tEnv.fromDataStream(modelData));
+ ReadWriteUtils.updateExistingParams(model, getParamMap());
+ return model;
+ }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return this.params;
+ }
+
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ }
+
+ public static Knn load(StreamExecutionEnvironment env, String path) throws
IOException {
+ return ReadWriteUtils.loadStageParam(path);
+ }
+
+ /**
+ * Prepares distributed knn model data. Constructs the sample matrix and
computes norm of
+ * features.
+ *
+ * @param inputData Input feature data with label.
+ * @return Distributed knn model.
+ */
+ private static DataStream<KnnModelData> prepareModelData(
+ DataStream<Tuple2<DenseVector, Double>> inputData) {
+ return DataStreamUtils.mapPartition(
+ inputData,
+ new RichMapPartitionFunction<Tuple2<DenseVector, Double>,
KnnModelData>() {
+ @Override
+ public void mapPartition(
+ Iterable<Tuple2<DenseVector, Double>> values,
+ Collector<KnnModelData> out) {
+ List<Tuple2<DenseVector, Double>> dataPoints = new
ArrayList<>(0);
+ for (Tuple2<DenseVector, Double> tuple2 : values) {
+ dataPoints.add(tuple2);
+ }
+ int featureDim = dataPoints.get(0).f0.size();
+ DenseMatrix packedFeatures = new
DenseMatrix(featureDim, dataPoints.size());
+ DenseVector labels = new
DenseVector(dataPoints.size());
+ for (int i = 0; i < dataPoints.size(); ++i) {
+ Tuple2<DenseVector, Double> tuple2 =
dataPoints.get(i);
+ labels.values[i] = tuple2.f1;
+ System.arraycopy(
+ tuple2.f0.values,
+ 0,
+ packedFeatures.values,
+ i * featureDim,
+ featureDim);
+ }
+ DenseVector featureNorms = computeNorm(packedFeatures);
+ if (dataPoints.size() > 0) {
+ out.collect(new KnnModelData(packedFeatures,
featureNorms, labels));
+ }
+ }
+ });
+ }
+
+ /**
+ * Merges knn model data.
+ *
+ * @param distributedModelData Distributed knn model data.
+ * @return Knn model.
+ */
+ private static DataStream<KnnModelData> mergeModelData(
+ DataStream<KnnModelData> distributedModelData) {
+ distributedModelData.getTransformation().setParallelism(1);
+ return DataStreamUtils.mapPartition(
Review comment:
I think we should not set the parallelism of `distributedModelData` to
one.
It is better to set the parallelism of the result stream to one if we want
to parallelize the computation of the norms.
What about the following implementation:
```
DataStream<KnnModelData> modelData = DataStreamUtils.mapPartition(...);
modelData.getTransformation().setParallelism(1);
return modelData;
```
##########
File path: flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java
##########
@@ -52,4 +52,37 @@ public static double norm2(DenseVector x) {
public static void scal(double a, DenseVector x) {
JAVA_BLAS.dscal(x.size(), a, x.values, 1);
}
-}
+
+ /**
+ * y = alpha * matrix * x + beta * y or y = alpha * (matrix^T) * x + beta
* y.
+ *
+ * @param matrix m x n matrix.
+ * @param transMatrix Whether transposes matrix before multiply.
+ * @param x dense vector with size n.
+ * @param y dense vector with size m.
+ */
+ public static void gemv(
+ double alpha,
+ DenseMatrix matrix,
+ boolean transMatrix,
+ DenseVector x,
+ double beta,
+ DenseVector y) {
+ Preconditions.checkArgument(
Review comment:
Can you also check the dimension of y?
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
##########
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the KNN algorithm.
+ *
+ * <p>See: https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm.
+ */
+public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> {
+
+ protected Map<Param<?>, Object> params = new HashMap<>();
+
+ public Knn() {
+ ParamUtils.initializeMapWithDefaultValues(params, this);
+ }
+
+ @Override
+ public KnnModel fit(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl)
inputs[0]).getTableEnvironment();
+ /* Tuple2 : <sampleVector, label> */
+ DataStream<Tuple2<DenseVector, Double>> inputData =
+ tEnv.toDataStream(inputs[0])
+ .map(
+ new MapFunction<Row, Tuple2<DenseVector,
Double>>() {
+ @Override
+ public Tuple2<DenseVector, Double> map(Row
value) {
+ Double label = (Double)
value.getField(getLabelCol());
+ DenseVector feature =
+ (DenseVector)
value.getField(getFeaturesCol());
+ return Tuple2.of(feature, label);
+ }
+ });
+ DataStream<KnnModelData> distributedModelData =
prepareModelData(inputData);
+ DataStream<KnnModelData> modelData =
mergeModelData(distributedModelData);
+ KnnModel model = new
KnnModel().setModelData(tEnv.fromDataStream(modelData));
+ ReadWriteUtils.updateExistingParams(model, getParamMap());
+ return model;
+ }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return this.params;
+ }
+
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ }
+
+ public static Knn load(StreamExecutionEnvironment env, String path) throws
IOException {
+ return ReadWriteUtils.loadStageParam(path);
+ }
+
+ /**
+ * Prepares distributed knn model data. Constructs the sample matrix and
computes norm of
+ * features.
+ *
+ * @param inputData Input feature data with label.
+ * @return Distributed knn model.
+ */
+ private static DataStream<KnnModelData> prepareModelData(
+ DataStream<Tuple2<DenseVector, Double>> inputData) {
+ return DataStreamUtils.mapPartition(
+ inputData,
+ new RichMapPartitionFunction<Tuple2<DenseVector, Double>,
KnnModelData>() {
+ @Override
+ public void mapPartition(
+ Iterable<Tuple2<DenseVector, Double>> values,
+ Collector<KnnModelData> out) {
+ List<Tuple2<DenseVector, Double>> dataPoints = new
ArrayList<>(0);
+ for (Tuple2<DenseVector, Double> tuple2 : values) {
+ dataPoints.add(tuple2);
+ }
+ int featureDim = dataPoints.get(0).f0.size();
+ DenseMatrix packedFeatures = new
DenseMatrix(featureDim, dataPoints.size());
+ DenseVector labels = new
DenseVector(dataPoints.size());
+ for (int i = 0; i < dataPoints.size(); ++i) {
+ Tuple2<DenseVector, Double> tuple2 =
dataPoints.get(i);
Review comment:
nits: Could we rename `tuple2` to `dataPoint`? This is more consistent
with `dataPoints`.
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+
+/** A Model which classifies data using the model data computed by {@link
Knn}. */
+public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> {
+ protected Map<Param<?>, Object> params = new HashMap<>();
Review comment:
protected --> private
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
##########
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the KNN algorithm.
+ *
+ * <p>See: https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm.
+ */
+public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> {
+
+ protected Map<Param<?>, Object> params = new HashMap<>();
+
+ public Knn() {
+ ParamUtils.initializeMapWithDefaultValues(params, this);
+ }
+
+ @Override
+ public KnnModel fit(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl)
inputs[0]).getTableEnvironment();
+ /* Tuple2 : <sampleVector, label> */
+ DataStream<Tuple2<DenseVector, Double>> inputData =
+ tEnv.toDataStream(inputs[0])
+ .map(
+ new MapFunction<Row, Tuple2<DenseVector,
Double>>() {
+ @Override
+ public Tuple2<DenseVector, Double> map(Row
value) {
+ Double label = (Double)
value.getField(getLabelCol());
+ DenseVector feature =
+ (DenseVector)
value.getField(getFeaturesCol());
+ return Tuple2.of(feature, label);
+ }
+ });
+ DataStream<KnnModelData> distributedModelData =
prepareModelData(inputData);
+ DataStream<KnnModelData> modelData =
mergeModelData(distributedModelData);
+ KnnModel model = new
KnnModel().setModelData(tEnv.fromDataStream(modelData));
+ ReadWriteUtils.updateExistingParams(model, getParamMap());
+ return model;
+ }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return this.params;
Review comment:
nits: Can we remove `this` here?
In general we only use this.variable if the class member variable name
collides with the function parameter name.
Also, could we rename `params` to `paramMap`? This could make the code style
more consistent, but it is OK with me to leave it as it is.
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
##########
@@ -0,0 +1,211 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the KNN algorithm.
+ *
+ * <p>See: https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm.
+ */
+public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> {
+
+ protected Map<Param<?>, Object> params = new HashMap<>();
+
+ public Knn() {
+ ParamUtils.initializeMapWithDefaultValues(params, this);
+ }
+
+ @Override
+ public KnnModel fit(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl)
inputs[0]).getTableEnvironment();
+ /* Tuple2 : <sampleVector, label> */
+ DataStream<Tuple2<DenseVector, Double>> inputData =
+ tEnv.toDataStream(inputs[0])
+ .map(
+ new MapFunction<Row, Tuple2<DenseVector,
Double>>() {
+ @Override
+ public Tuple2<DenseVector, Double> map(Row
value) {
+ Double label = (Double)
value.getField(getLabelCol());
+ DenseVector feature =
+ (DenseVector)
value.getField(getFeaturesCol());
+ return Tuple2.of(feature, label);
+ }
+ });
+ DataStream<KnnModelData> distributedModelData =
prepareModelData(inputData);
+ DataStream<KnnModelData> modelData =
mergeModelData(distributedModelData);
+ KnnModel model = new
KnnModel().setModelData(tEnv.fromDataStream(modelData));
+ ReadWriteUtils.updateExistingParams(model, getParamMap());
+ return model;
+ }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return this.params;
+ }
+
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ }
+
+ public static Knn load(StreamExecutionEnvironment env, String path) throws
IOException {
+ return ReadWriteUtils.loadStageParam(path);
+ }
+
+ /**
+ * Prepares distributed knn model data. Constructs the sample matrix and
computes norm of
+ * features.
+ *
+ * @param inputData Input feature data with label.
+ * @return Distributed knn model.
+ */
+ private static DataStream<KnnModelData> prepareModelData(
+ DataStream<Tuple2<DenseVector, Double>> inputData) {
+ return DataStreamUtils.mapPartition(
+ inputData,
+ new RichMapPartitionFunction<Tuple2<DenseVector, Double>,
KnnModelData>() {
+ @Override
+ public void mapPartition(
+ Iterable<Tuple2<DenseVector, Double>> values,
+ Collector<KnnModelData> out) {
+ List<Tuple2<DenseVector, Double>> dataPoints = new
ArrayList<>(0);
+ for (Tuple2<DenseVector, Double> tuple2 : values) {
+ dataPoints.add(tuple2);
+ }
+ int featureDim = dataPoints.get(0).f0.size();
+ DenseMatrix packedFeatures = new
DenseMatrix(featureDim, dataPoints.size());
+ DenseVector labels = new
DenseVector(dataPoints.size());
+ for (int i = 0; i < dataPoints.size(); ++i) {
+ Tuple2<DenseVector, Double> tuple2 =
dataPoints.get(i);
+ labels.values[i] = tuple2.f1;
+ System.arraycopy(
+ tuple2.f0.values,
+ 0,
+ packedFeatures.values,
+ i * featureDim,
+ featureDim);
+ }
+ DenseVector featureNorms = computeNorm(packedFeatures);
+ if (dataPoints.size() > 0) {
+ out.collect(new KnnModelData(packedFeatures,
featureNorms, labels));
+ }
+ }
+ });
+ }
+
+ /**
+ * Merges knn model data.
+ *
+ * @param distributedModelData Distributed knn model data.
+ * @return Knn model.
+ */
+ private static DataStream<KnnModelData> mergeModelData(
+ DataStream<KnnModelData> distributedModelData) {
+ distributedModelData.getTransformation().setParallelism(1);
+ return DataStreamUtils.mapPartition(
+ distributedModelData,
+ new RichMapPartitionFunction<KnnModelData, KnnModelData>() {
+ @Override
+ public void mapPartition(
+ Iterable<KnnModelData> values,
Collector<KnnModelData> out) {
+ List<KnnModelData> buffer = new ArrayList<>(1);
+ int totalNumVec = 0;
+ for (KnnModelData data : values) {
+ totalNumVec += data.packedFeatures.numCols();
+ buffer.add(data);
+ }
+ int featureDim =
buffer.get(0).packedFeatures.numRows();
+ DenseMatrix packedFeatures = new
DenseMatrix(featureDim, totalNumVec);
+ DenseVector featureNorms = new
DenseVector(totalNumVec);
+ DenseVector labels = new DenseVector(totalNumVec);
+ int offset = 0;
+ for (KnnModelData data : buffer) {
+ int numVec = data.featureNorms.size();
+ System.arraycopy(
+ data.packedFeatures.values,
+ 0,
+ packedFeatures.values,
+ offset * featureDim,
+ featureDim * numVec);
+ System.arraycopy(
+ data.featureNorms.values,
+ 0,
+ featureNorms.values,
+ offset,
+ numVec);
+ System.arraycopy(data.labels.values, 0,
labels.values, offset, numVec);
+ offset += numVec;
+ }
+ out.collect(new KnnModelData(packedFeatures,
featureNorms, labels));
+ }
+ });
+ }
+
+ /**
+ * For Euclidean distance, distance = sqrt((a - b)^2) = (sqrt(a^2 + b^2 -
2ab)) So it can
Review comment:
Can we polish the comments here and move them to `prepareModelData()`?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]