zhipeng93 commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r766506565
##########
File path: flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java
##########
@@ -22,13 +22,84 @@
/** A utility class that provides BLAS routines over matrices and vectors. */
public class BLAS {
+
/** For level-1 function dspmv, use javaBLAS for better performance. */
private static final dev.ludovic.netlib.BLAS JAVA_BLAS =
dev.ludovic.netlib.JavaBLAS.getInstance();
- /** y += a * x . */
+ /**
+ * \sum_i |x_i| .
+ *
+ * @param x x
+ * @return \sum_i |x_i|
+ */
+ public static double asum(DenseVector x) {
+ return JAVA_BLAS.dasum(x.size(), x.values, 0, 1);
+ }
+
+ /**
+ * y += a * x .
+ *
+ * @param a a
+ * @param x x
+ * @param y y
+ */
public static void axpy(double a, DenseVector x, DenseVector y) {
Preconditions.checkArgument(x.size() == y.size(), "Vector size
mismatched.");
JAVA_BLAS.daxpy(x.size(), a, x.values, 1, y.values, 1);
}
+
+ /**
+ * x \cdot y .
+ *
+ * @param x x
+ * @param y y
+ * @return x \cdot y
+ */
+ public static double dot(DenseVector x, DenseVector y) {
+ Preconditions.checkArgument(x.size() == y.size(), "Vector size
mismatched.");
+ return JAVA_BLAS.ddot(x.size(), x.values, 1, y.values, 1);
+ }
+
+ /**
+ * \sqrt(\sum_i x_i * x_i) .
+ *
+ * @param x x
+ * @return \sqrt(\sum_i x_i * x_i)
+ */
+ public static double norm2(DenseVector x) {
+ return JAVA_BLAS.dnrm2(x.size(), x.values, 1);
+ }
+
+ /**
+ * x = x * a .
+ *
+ * @param a a
+ * @param x x
+ */
+ public static void scal(double a, DenseVector x) {
+ JAVA_BLAS.dscal(x.size(), a, x.values, 1);
+ }
+
+ /**
+ * y := alpha * A * x + beta * y.
+ *
+ * @param matA m x n matrix A.
+ * @param transA transform matrix or not.
+ * @param x dense vector with size n.
+ * @param y dense vector with size m.
+ */
+ public static void gemv(
Review comment:
Can you add a unit test for this method?
Also, can you improve the Javadoc here? E.g.,
`y := alpha * A * x + beta * y.` --> `y = alpha * A * x + beta * y or y =
alpha * (A^T) * x + beta * y`
Also, can you explain the `transA` parameter more clearly?
##########
File path: flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Matrix.java
##########
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.linalg;
+
+import java.io.Serializable;
+
+/** A matrix of double values. */
+public interface Matrix extends Serializable {
+
+ /** Gets number of rows. */
+ int numRows();
+
+ /** Gets number of columns. */
+ int numCols();
+
+ /** Gets value of the (i,j) element. */
+ double get(int i, int j);
+
+ /** Converts the instance to a double array. */
+ double[] toArray();
Review comment:
Can we remove this method for now, given that it is never used?
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
##########
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the KNN algorithm.
+ *
+ * <p>See: https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm.
+ */
+public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> {
+
+ protected Map<Param<?>, Object> params = new HashMap<>();
+
+ public Knn() {
+ ParamUtils.initializeMapWithDefaultValues(params, this);
+ }
+
+ @Override
+ public KnnModel fit(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl)
inputs[0]).getTableEnvironment();
+ /* Tuple2 : <sampleVector, label> */
+ DataStream<Tuple2<DenseVector, Double>> inputData =
+ tEnv.toDataStream(inputs[0])
+ .map(
+ new MapFunction<Row, Tuple2<DenseVector,
Double>>() {
+ @Override
+ public Tuple2<DenseVector, Double> map(Row
value) {
+ Double label = (Double)
value.getField(getLabelCol());
+ DenseVector feature =
+ (DenseVector)
value.getField(getFeaturesCol());
+ return Tuple2.of(feature, label);
+ }
+ });
+ DataStream<KnnModelData> modelData = prepareModelData(inputData);
+ KnnModel model = new
KnnModel().setModelData(tEnv.fromDataStream(modelData));
+ ReadWriteUtils.updateExistingParams(model, getParamMap());
+ return model;
+ }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return this.params;
+ }
+
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ }
+
+ public static Knn load(StreamExecutionEnvironment env, String path) throws
IOException {
+ return ReadWriteUtils.loadStageParam(path);
+ }
+
+ /**
+ * Prepares knn model data. Constructs the sample matrix and computes norm
of vectors.
+ *
+ * @param inputData Input vector data with label.
+ * @return Knn model.
+ */
+ private static DataStream<KnnModelData> prepareModelData(
+ DataStream<Tuple2<DenseVector, Double>> inputData) {
+ return DataStreamUtils.mapPartition(
+ inputData,
+ new RichMapPartitionFunction<Tuple2<DenseVector, Double>,
KnnModelData>() {
+ @Override
+ public void mapPartition(
+ Iterable<Tuple2<DenseVector, Double>> values,
+ Collector<KnnModelData> out) {
+ Tuple3<DenseMatrix, DenseVector, DenseVector> model =
prepareData(values);
+ if (model != null) {
+ out.collect(new KnnModelData(model));
+ }
+ }
+ },
+ TypeInformation.of(KnnModelData.class));
+ }
+
+ /**
+ * Prepares knn model data, the output is a Tuple3, which includes matrix,
vector norms and
+ * labels.
+ *
+ * @param trainData Input train data.
+ * @return Model data in format of tuple3.
+ */
+ private static Tuple3<DenseMatrix, DenseVector, DenseVector> prepareData(
+ Iterable<Tuple2<DenseVector, Double>> trainData) {
+ List<Tuple2<DenseVector, Double>> buffer = new ArrayList<>(0);
+ int vecSize = -1;
+ for (Tuple2<DenseVector, Double> tuple2 : trainData) {
+ if (vecSize == -1) {
+ vecSize = tuple2.f0.size();
+ }
+ buffer.add(tuple2);
+ }
+ if (vecSize == -1) {
+ return null;
+ }
+ DenseMatrix matrix = new DenseMatrix(vecSize, buffer.size());
+ DenseVector label = new DenseVector(buffer.size());
+ for (int i = 0; i < buffer.size(); ++i) {
+ Tuple2<DenseVector, Double> tuple2 = buffer.get(i);
+ label.values[i] = tuple2.f1;
+ double[] vectorData = tuple2.f0.toArray();
+ double[] matrixData = matrix.values;
+ System.arraycopy(vectorData, 0, matrixData, i * vecSize, vecSize);
+ }
+ DenseVector norm = computeNorm(matrix);
Review comment:
How about `DenseVector featureNorms =
computeFeatureNorms(packedFeatures);`?
##########
File path:
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##########
@@ -0,0 +1,251 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+
+/** Tests {@link Knn} and {@link KnnModel}. */
+public class KnnTest {
+ @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+ private StreamExecutionEnvironment env;
+ private StreamTableEnvironment tEnv;
+ private Table trainData;
+ private static final List<Row> trainArray =
Review comment:
Can we have another name for `trainArray`, given that it is actually a
list?
##########
File path:
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##########
@@ -0,0 +1,251 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+
+/** Tests {@link Knn} and {@link KnnModel}. */
+public class KnnTest {
+ @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+ private StreamExecutionEnvironment env;
+ private StreamTableEnvironment tEnv;
+ private Table trainData;
+ private static final List<Row> trainArray =
+ new ArrayList<>(
+ Arrays.asList(
+ Row.of(Vectors.dense(2.0, 3.0), 1.0),
+ Row.of(Vectors.dense(2.1, 3.1), 1.0),
+ Row.of(Vectors.dense(200.1, 300.1), 2.0),
+ Row.of(Vectors.dense(200.2, 300.2), 2.0),
+ Row.of(Vectors.dense(200.3, 300.3), 2.0),
+ Row.of(Vectors.dense(200.4, 300.4), 2.0),
+ Row.of(Vectors.dense(200.4, 300.4), 2.0),
+ Row.of(Vectors.dense(200.6, 300.6), 2.0),
+ Row.of(Vectors.dense(2.1, 3.1), 1.0),
+ Row.of(Vectors.dense(2.1, 3.1), 1.0),
+ Row.of(Vectors.dense(2.1, 3.1), 1.0),
+ Row.of(Vectors.dense(2.1, 3.1), 1.0),
+ Row.of(Vectors.dense(2.3, 3.2), 1.0),
+ Row.of(Vectors.dense(2.3, 3.2), 1.0),
+ Row.of(Vectors.dense(2.8, 3.2), 3.0),
+ Row.of(Vectors.dense(300., 3.2), 4.0),
+ Row.of(Vectors.dense(2.2, 3.2), 1.0),
+ Row.of(Vectors.dense(2.4, 3.2), 5.0),
+ Row.of(Vectors.dense(2.5, 3.2), 5.0),
+ Row.of(Vectors.dense(2.5, 3.2), 5.0),
+ Row.of(Vectors.dense(2.1, 3.1), 1.0)));
+
+ private static final List<Row> testArray =
+ new ArrayList<>(
+ Arrays.asList(
+ Row.of(Vectors.dense(4.0, 4.1), 5.0),
+ Row.of(Vectors.dense(300, 42), 2.0)));
+ private Table testData;
+
+ @Before
+ public void before() {
+ Configuration config = new Configuration();
+
config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH,
true);
+ env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+ env.setParallelism(4);
+ env.enableCheckpointing(100);
+ env.setRestartStrategy(RestartStrategies.noRestart());
+ tEnv = StreamTableEnvironment.create(env);
+
+ Schema schema =
+ Schema.newBuilder()
+ .column("f0", DataTypes.of(DenseVector.class))
+ .column("f1", DataTypes.DOUBLE())
+ .build();
+
+ DataStream<Row> dataStream = env.fromCollection(trainArray);
+ trainData = tEnv.fromDataStream(dataStream, schema).as("features",
"label");
+
+ DataStream<Row> predDataStream = env.fromCollection(testArray);
+ testData = tEnv.fromDataStream(predDataStream, schema).as("features",
"label");
+ }
+
+ private static List<Tuple2<Double, Double>> executeAndCollect(
+ Table output, String labelCol, String predictionCol) throws
Exception {
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl)
output).getTableEnvironment();
+
+ DataStream<Tuple2<Double, Double>> stream =
+ tEnv.toDataStream(output)
+ .map(
+ new MapFunction<Row, Tuple2<Double, Double>>()
{
+ @Override
+ public Tuple2<Double, Double> map(Row row)
{
+ return Tuple2.of(
+ (Double)
row.getField(labelCol),
+ (Double)
row.getField(predictionCol));
+ }
+ });
+ return IteratorUtils.toList(stream.executeAndCollect());
+ }
+
+ private static void verifyClusteringResult(List<Tuple2<Double, Double>>
result) {
+ for (Tuple2<Double, Double> t2 : result) {
+ Assert.assertEquals(t2.f0, t2.f1);
+ }
+ }
+
+ @Test
+ public void testParam() {
+ Knn knn = new Knn();
+ assertEquals("features", knn.getFeaturesCol());
+ assertEquals("label", knn.getLabelCol());
+ assertEquals(10L, knn.getK().longValue());
+ assertEquals("prediction", knn.getPredictionCol());
+
+ knn.setLabelCol("test_label")
+ .setFeaturesCol("test_features")
+ .setK(4)
+ .setPredictionCol("test_prediction");
+
+ assertEquals("test_features", knn.getFeaturesCol());
+ assertEquals("test_label", knn.getLabelCol());
+ assertEquals(4L, knn.getK().longValue());
+ assertEquals("test_prediction", knn.getPredictionCol());
+ }
+
+ @Test
+ public void testFeaturePredictionParam() throws Exception {
+ Knn knn =
+ new Knn()
+ .setLabelCol("test_label")
+ .setFeaturesCol("test_features")
+ .setK(4)
+ .setPredictionCol("test_prediction");
+ KnnModel model = knn.fit(trainData.as("test_features, test_label"));
+ Table output = model.transform(testData.as("test_features,
test_label"))[0];
+
+ assertEquals(
+ Arrays.asList("test_features", "test_label",
"test_prediction"),
+ output.getResolvedSchema().getColumnNames());
+
+ List<Tuple2<Double, Double>> result =
+ executeAndCollect(output, "test_label", "test_prediction");
+ verifyClusteringResult(result);
+ }
+
+ @Test
+ public void testFewerDistinctPointsThanCluster() throws Exception {
+ Knn knn = new Knn().setK(4);
+ KnnModel model = knn.fit(testData);
+ Table output = model.transform(testData)[0];
+ executeAndCollect(output, "label", "prediction");
Review comment:
Can we replace `executeAndCollect(output, "label", "prediction");` with
`executeAndCollect(output, knn.getLabelCol(), knn.getPredictionCol())` here?
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
##########
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the KNN algorithm.
+ *
+ * <p>See: https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm.
+ */
+public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> {
+
+ protected Map<Param<?>, Object> params = new HashMap<>();
+
+ public Knn() {
+ ParamUtils.initializeMapWithDefaultValues(params, this);
+ }
+
+ @Override
+ public KnnModel fit(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl)
inputs[0]).getTableEnvironment();
+ /* Tuple2 : <sampleVector, label> */
+ DataStream<Tuple2<DenseVector, Double>> inputData =
+ tEnv.toDataStream(inputs[0])
+ .map(
+ new MapFunction<Row, Tuple2<DenseVector,
Double>>() {
+ @Override
+ public Tuple2<DenseVector, Double> map(Row
value) {
+ Double label = (Double)
value.getField(getLabelCol());
+ DenseVector feature =
+ (DenseVector)
value.getField(getFeaturesCol());
+ return Tuple2.of(feature, label);
+ }
+ });
+ DataStream<KnnModelData> modelData = prepareModelData(inputData);
+ KnnModel model = new
KnnModel().setModelData(tEnv.fromDataStream(modelData));
+ ReadWriteUtils.updateExistingParams(model, getParamMap());
+ return model;
+ }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return this.params;
+ }
+
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ }
+
+ public static Knn load(StreamExecutionEnvironment env, String path) throws
IOException {
+ return ReadWriteUtils.loadStageParam(path);
+ }
+
+ /**
+ * Prepares knn model data. Constructs the sample matrix and computes norm
of vectors.
+ *
+ * @param inputData Input vector data with label.
+ * @return Knn model.
+ */
+ private static DataStream<KnnModelData> prepareModelData(
+ DataStream<Tuple2<DenseVector, Double>> inputData) {
+ return DataStreamUtils.mapPartition(
+ inputData,
+ new RichMapPartitionFunction<Tuple2<DenseVector, Double>,
KnnModelData>() {
+ @Override
+ public void mapPartition(
+ Iterable<Tuple2<DenseVector, Double>> values,
+ Collector<KnnModelData> out) {
+ Tuple3<DenseMatrix, DenseVector, DenseVector> model =
prepareData(values);
+ if (model != null) {
+ out.collect(new KnnModelData(model));
+ }
+ }
+ },
+ TypeInformation.of(KnnModelData.class));
+ }
+
+ /**
+ * Prepares knn model data, the output is a Tuple3, which includes matrix,
vector norms and
+ * labels.
+ *
+ * @param trainData Input train data.
+ * @return Model data in format of tuple3.
+ */
+ private static Tuple3<DenseMatrix, DenseVector, DenseVector> prepareData(
Review comment:
Can you move the logic of this `prepareData` into `prepareModelData`? That
would be clearer.
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
##########
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the KNN algorithm.
+ *
+ * <p>See: https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm.
+ */
+public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> {
+
+ protected Map<Param<?>, Object> params = new HashMap<>();
+
+ public Knn() {
+ ParamUtils.initializeMapWithDefaultValues(params, this);
+ }
+
+ @Override
+ public KnnModel fit(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl)
inputs[0]).getTableEnvironment();
+ /* Tuple2 : <sampleVector, label> */
+ DataStream<Tuple2<DenseVector, Double>> inputData =
+ tEnv.toDataStream(inputs[0])
+ .map(
+ new MapFunction<Row, Tuple2<DenseVector,
Double>>() {
+ @Override
+ public Tuple2<DenseVector, Double> map(Row
value) {
+ Double label = (Double)
value.getField(getLabelCol());
+ DenseVector feature =
+ (DenseVector)
value.getField(getFeaturesCol());
+ return Tuple2.of(feature, label);
+ }
+ });
+ DataStream<KnnModelData> modelData = prepareModelData(inputData);
Review comment:
It seems that the KnnModelData here is a distributed one. Should we
merge the pieces into one here, since you need to construct it as one during
inference anyway?
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
##########
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the KNN algorithm.
+ *
+ * <p>See: https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm.
+ */
+public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> {
+
+ protected Map<Param<?>, Object> params = new HashMap<>();
+
+ public Knn() {
+ ParamUtils.initializeMapWithDefaultValues(params, this);
+ }
+
+ @Override
+ public KnnModel fit(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl)
inputs[0]).getTableEnvironment();
+ /* Tuple2 : <sampleVector, label> */
+ DataStream<Tuple2<DenseVector, Double>> inputData =
+ tEnv.toDataStream(inputs[0])
+ .map(
+ new MapFunction<Row, Tuple2<DenseVector,
Double>>() {
+ @Override
+ public Tuple2<DenseVector, Double> map(Row
value) {
+ Double label = (Double)
value.getField(getLabelCol());
+ DenseVector feature =
+ (DenseVector)
value.getField(getFeaturesCol());
+ return Tuple2.of(feature, label);
+ }
+ });
+ DataStream<KnnModelData> modelData = prepareModelData(inputData);
+ KnnModel model = new
KnnModel().setModelData(tEnv.fromDataStream(modelData));
+ ReadWriteUtils.updateExistingParams(model, getParamMap());
+ return model;
+ }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return this.params;
+ }
+
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ }
+
+ public static Knn load(StreamExecutionEnvironment env, String path) throws
IOException {
+ return ReadWriteUtils.loadStageParam(path);
+ }
+
+ /**
+ * Prepares knn model data. Constructs the sample matrix and computes norm
of vectors.
+ *
+ * @param inputData Input vector data with label.
+ * @return Knn model.
+ */
+ private static DataStream<KnnModelData> prepareModelData(
+ DataStream<Tuple2<DenseVector, Double>> inputData) {
+ return DataStreamUtils.mapPartition(
+ inputData,
+ new RichMapPartitionFunction<Tuple2<DenseVector, Double>,
KnnModelData>() {
+ @Override
+ public void mapPartition(
+ Iterable<Tuple2<DenseVector, Double>> values,
+ Collector<KnnModelData> out) {
+ Tuple3<DenseMatrix, DenseVector, DenseVector> model =
prepareData(values);
+ if (model != null) {
+ out.collect(new KnnModelData(model));
+ }
+ }
+ },
+ TypeInformation.of(KnnModelData.class));
+ }
+
+ /**
+ * Prepares knn model data, the output is a Tuple3, which includes matrix,
vector norms and
+ * labels.
+ *
+ * @param trainData Input train data.
+ * @return Model data in format of tuple3.
+ */
+ private static Tuple3<DenseMatrix, DenseVector, DenseVector> prepareData(
+ Iterable<Tuple2<DenseVector, Double>> trainData) {
+ List<Tuple2<DenseVector, Double>> buffer = new ArrayList<>(0);
+ int vecSize = -1;
+ for (Tuple2<DenseVector, Double> tuple2 : trainData) {
Review comment:
Can you extract `vecSize` out of the loop? Also, would `featureDim` be a
better name?
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
##########
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the KNN algorithm.
+ *
+ * <p>See: https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm.
+ */
+public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> {
+
+ protected Map<Param<?>, Object> params = new HashMap<>();
+
+ public Knn() {
+ ParamUtils.initializeMapWithDefaultValues(params, this);
+ }
+
+ @Override
+ public KnnModel fit(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl)
inputs[0]).getTableEnvironment();
+ /* Tuple2 : <sampleVector, label> */
+ DataStream<Tuple2<DenseVector, Double>> inputData =
+ tEnv.toDataStream(inputs[0])
+ .map(
+ new MapFunction<Row, Tuple2<DenseVector,
Double>>() {
+ @Override
+ public Tuple2<DenseVector, Double> map(Row
value) {
+ Double label = (Double)
value.getField(getLabelCol());
+ DenseVector feature =
+ (DenseVector)
value.getField(getFeaturesCol());
+ return Tuple2.of(feature, label);
+ }
+ });
+ DataStream<KnnModelData> modelData = prepareModelData(inputData);
+ KnnModel model = new
KnnModel().setModelData(tEnv.fromDataStream(modelData));
+ ReadWriteUtils.updateExistingParams(model, getParamMap());
+ return model;
+ }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return this.params;
+ }
+
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ }
+
+ public static Knn load(StreamExecutionEnvironment env, String path) throws
IOException {
+ return ReadWriteUtils.loadStageParam(path);
+ }
+
+ /**
+ * Prepares knn model data. Constructs the sample matrix and computes norm
of vectors.
+ *
+ * @param inputData Input vector data with label.
+ * @return Knn model.
+ */
+ private static DataStream<KnnModelData> prepareModelData(
+ DataStream<Tuple2<DenseVector, Double>> inputData) {
+ return DataStreamUtils.mapPartition(
+ inputData,
+ new RichMapPartitionFunction<Tuple2<DenseVector, Double>,
KnnModelData>() {
+ @Override
+ public void mapPartition(
+ Iterable<Tuple2<DenseVector, Double>> values,
+ Collector<KnnModelData> out) {
+ Tuple3<DenseMatrix, DenseVector, DenseVector> model =
prepareData(values);
+ if (model != null) {
+ out.collect(new KnnModelData(model));
+ }
+ }
+ },
+ TypeInformation.of(KnnModelData.class));
+ }
+
+ /**
+ * Prepares knn model data, the output is a Tuple3, which includes matrix,
vector norms and
+ * labels.
+ *
+ * @param trainData Input train data.
+ * @return Model data in format of tuple3.
+ */
+ private static Tuple3<DenseMatrix, DenseVector, DenseVector> prepareData(
+ Iterable<Tuple2<DenseVector, Double>> trainData) {
+ List<Tuple2<DenseVector, Double>> buffer = new ArrayList<>(0);
Review comment:
Is `dataPoints` a better name for `buffer`?
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
##########
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the KNN algorithm.
+ *
+ * <p>See: https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm.
+ */
+public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> {
+
+ protected Map<Param<?>, Object> params = new HashMap<>();
+
+ public Knn() {
+ ParamUtils.initializeMapWithDefaultValues(params, this);
+ }
+
+ @Override
+ public KnnModel fit(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl)
inputs[0]).getTableEnvironment();
+ /* Tuple2 : <sampleVector, label> */
+ DataStream<Tuple2<DenseVector, Double>> inputData =
+ tEnv.toDataStream(inputs[0])
+ .map(
+ new MapFunction<Row, Tuple2<DenseVector,
Double>>() {
+ @Override
+ public Tuple2<DenseVector, Double> map(Row
value) {
+ Double label = (Double)
value.getField(getLabelCol());
+ DenseVector feature =
+ (DenseVector)
value.getField(getFeaturesCol());
+ return Tuple2.of(feature, label);
+ }
+ });
+ DataStream<KnnModelData> modelData = prepareModelData(inputData);
+ KnnModel model = new
KnnModel().setModelData(tEnv.fromDataStream(modelData));
+ ReadWriteUtils.updateExistingParams(model, getParamMap());
+ return model;
+ }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return this.params;
+ }
+
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ }
+
+ public static Knn load(StreamExecutionEnvironment env, String path) throws
IOException {
+ return ReadWriteUtils.loadStageParam(path);
+ }
+
+ /**
+ * Prepares knn model data. Constructs the sample matrix and computes norm
of vectors.
+ *
+ * @param inputData Input vector data with label.
+ * @return Knn model.
+ */
+ private static DataStream<KnnModelData> prepareModelData(
+ DataStream<Tuple2<DenseVector, Double>> inputData) {
+ return DataStreamUtils.mapPartition(
+ inputData,
+ new RichMapPartitionFunction<Tuple2<DenseVector, Double>,
KnnModelData>() {
+ @Override
+ public void mapPartition(
+ Iterable<Tuple2<DenseVector, Double>> values,
+ Collector<KnnModelData> out) {
+ Tuple3<DenseMatrix, DenseVector, DenseVector> model =
prepareData(values);
+ if (model != null) {
+ out.collect(new KnnModelData(model));
+ }
+ }
+ },
+ TypeInformation.of(KnnModelData.class));
+ }
+
+ /**
+ * Prepares knn model data, the output is a Tuple3, which includes matrix,
vector norms and
+ * labels.
+ *
+ * @param trainData Input train data.
+ * @return Model data in format of tuple3.
+ */
+ private static Tuple3<DenseMatrix, DenseVector, DenseVector> prepareData(
+ Iterable<Tuple2<DenseVector, Double>> trainData) {
+ List<Tuple2<DenseVector, Double>> buffer = new ArrayList<>(0);
+ int vecSize = -1;
+ for (Tuple2<DenseVector, Double> tuple2 : trainData) {
+ if (vecSize == -1) {
+ vecSize = tuple2.f0.size();
+ }
+ buffer.add(tuple2);
+ }
+ if (vecSize == -1) {
+ return null;
+ }
+ DenseMatrix matrix = new DenseMatrix(vecSize, buffer.size());
+ DenseVector label = new DenseVector(buffer.size());
+ for (int i = 0; i < buffer.size(); ++i) {
+ Tuple2<DenseVector, Double> tuple2 = buffer.get(i);
+ label.values[i] = tuple2.f1;
+ double[] vectorData = tuple2.f0.toArray();
+ double[] matrixData = matrix.values;
+ System.arraycopy(vectorData, 0, matrixData, i * vecSize, vecSize);
+ }
+ DenseVector norm = computeNorm(matrix);
+ return Tuple3.of(matrix, norm, label);
+ }
+
+ /**
+ * For Euclidean distance, distance = sqrt((a - b)^2) = (sqrt(a^2 + b^2 -
2ab)) So it can
+ * pre-calculate the L2 norm square of the vector, and when calculating
the distance with
+ * another vector, only dot product is calculated.
+ *
+ * @param matrix Input matrix.
+ * @return Norm vector, vector length is the number of samples.
+ */
+ private static DenseVector computeNorm(DenseMatrix matrix) {
Review comment:
Can we also rename `matrix` here to `packedFeatures`?
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
##########
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the KNN algorithm.
+ *
+ * <p>See: https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm.
+ */
+public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> {
+
+ protected Map<Param<?>, Object> params = new HashMap<>();
+
+ public Knn() {
+ ParamUtils.initializeMapWithDefaultValues(params, this);
+ }
+
+ @Override
+ public KnnModel fit(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl)
inputs[0]).getTableEnvironment();
+ /* Tuple2 : <sampleVector, label> */
+ DataStream<Tuple2<DenseVector, Double>> inputData =
+ tEnv.toDataStream(inputs[0])
+ .map(
+ new MapFunction<Row, Tuple2<DenseVector,
Double>>() {
+ @Override
+ public Tuple2<DenseVector, Double> map(Row
value) {
+ Double label = (Double)
value.getField(getLabelCol());
+ DenseVector feature =
+ (DenseVector)
value.getField(getFeaturesCol());
+ return Tuple2.of(feature, label);
+ }
+ });
+ DataStream<KnnModelData> modelData = prepareModelData(inputData);
+ KnnModel model = new
KnnModel().setModelData(tEnv.fromDataStream(modelData));
+ ReadWriteUtils.updateExistingParams(model, getParamMap());
+ return model;
+ }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return this.params;
+ }
+
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ }
+
+ public static Knn load(StreamExecutionEnvironment env, String path) throws
IOException {
+ return ReadWriteUtils.loadStageParam(path);
+ }
+
+ /**
+ * Prepares knn model data. Constructs the sample matrix and computes norm
of vectors.
+ *
+ * @param inputData Input vector data with label.
+ * @return Knn model.
+ */
+ private static DataStream<KnnModelData> prepareModelData(
+ DataStream<Tuple2<DenseVector, Double>> inputData) {
+ return DataStreamUtils.mapPartition(
+ inputData,
+ new RichMapPartitionFunction<Tuple2<DenseVector, Double>,
KnnModelData>() {
+ @Override
+ public void mapPartition(
+ Iterable<Tuple2<DenseVector, Double>> values,
+ Collector<KnnModelData> out) {
+ Tuple3<DenseMatrix, DenseVector, DenseVector> model =
prepareData(values);
+ if (model != null) {
+ out.collect(new KnnModelData(model));
+ }
+ }
+ },
+ TypeInformation.of(KnnModelData.class));
+ }
+
+ /**
+ * Prepares knn model data, the output is a Tuple3, which includes matrix,
vector norms and
+ * labels.
+ *
+ * @param trainData Input train data.
+ * @return Model data in format of tuple3.
+ */
+ private static Tuple3<DenseMatrix, DenseVector, DenseVector> prepareData(
+ Iterable<Tuple2<DenseVector, Double>> trainData) {
+ List<Tuple2<DenseVector, Double>> buffer = new ArrayList<>(0);
+ int vecSize = -1;
+ for (Tuple2<DenseVector, Double> tuple2 : trainData) {
+ if (vecSize == -1) {
+ vecSize = tuple2.f0.size();
+ }
+ buffer.add(tuple2);
+ }
+ if (vecSize == -1) {
+ return null;
+ }
+ DenseMatrix matrix = new DenseMatrix(vecSize, buffer.size());
Review comment:
Is `packedFeatures` a better name for `matrix`?
`label` --> `labels`
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,273 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.classification.knn.KnnModelData.ModelDataDecoder;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+
+/** A Model which classifies data using the model data computed by {@link
Knn}. */
+public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> {
+ protected Map<Param<?>, Object> params = new HashMap<>();
+ private Table modelDataTable;
+
+ public KnnModel() {
+ ParamUtils.initializeMapWithDefaultValues(params, this);
+ }
+
+ @Override
+ public KnnModel setModelData(Table... modelData) {
+ this.modelDataTable = modelData[0];
+ return this;
+ }
+
+ @Override
+ public Table[] getModelData() {
+ return new Table[] {modelDataTable};
+ }
+
+ @Override
+ @SuppressWarnings("unchecked")
+ public Table[] transform(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl)
inputs[0]).getTableEnvironment();
+ DataStream<Row> data = tEnv.toDataStream(inputs[0]);
+ DataStream<KnnModelData> knnModel =
KnnModelData.getModelDataStream(modelDataTable);
+ final String broadcastModelKey = "broadcastModelKey";
+
+ RowTypeInfo inputTypeInfo =
TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+ RowTypeInfo outputTypeInfo =
+ new RowTypeInfo(
+ ArrayUtils.addAll(
+ inputTypeInfo.getFieldTypes(),
BasicTypeInfo.DOUBLE_TYPE_INFO),
+ ArrayUtils.addAll(inputTypeInfo.getFieldNames(),
getPredictionCol()));
+ DataStream<Row> output =
+ BroadcastUtils.withBroadcastStream(
+ Collections.singletonList(data),
+ Collections.singletonMap(broadcastModelKey, knnModel),
+ inputList -> {
+ DataStream input = inputList.get(0);
+ return input.map(
+ new PredictLabelFunction(
+ broadcastModelKey, getK(),
getFeaturesCol()),
+ outputTypeInfo);
+ });
+ return new Table[] {tEnv.fromDataStream(output)};
+ }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return this.params;
+ }
+
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ ReadWriteUtils.saveModelData(
+ KnnModelData.getModelDataStream(modelDataTable),
+ path,
+ new KnnModelData.ModelDataEncoder(),
+ 0);
+ }
+
+ /**
+ * Loads model data from path.
+ *
+ * @param env Stream execution environment.
+ * @param path Model path.
+ * @return Knn model.
+ */
+ public static KnnModel load(StreamExecutionEnvironment env, String path)
throws IOException {
+ KnnModel model = ReadWriteUtils.loadStageParam(path);
+ Table modelDataTable = ReadWriteUtils.loadModelData(env, path, new
ModelDataDecoder(), 0);
+ return model.setModelData(modelDataTable);
+ }
+
+ /** This operator loads model data and predicts result. */
+ private static class PredictLabelFunction extends RichMapFunction<Row,
Row> {
+ private final String featureCol;
+ private KnnModelData knnModelData;
+ private final int k;
+ private final String broadcastKey;
+
+ public PredictLabelFunction(String broadcastKey, int k, String
featureCol) {
+ this.k = k;
+ this.broadcastKey = broadcastKey;
+ this.featureCol = featureCol;
+ }
+
+ @Override
+ public Row map(Row row) {
+ if (knnModelData == null) {
+
loadModel(getRuntimeContext().getBroadcastVariable(broadcastKey));
+ }
+ DenseVector vector = (DenseVector) row.getField(featureCol);
+ Tuple2<List<Double>, List<Double>> tuple2 = findNeighbor(vector);
+ return Row.join(row, Row.of(getResult(tuple2)));
+ }
+
+ /** Finds the nearest k neighbors from whole vectors in matrix format.
*/
+ private Tuple2<List<Double>, List<Double>> findNeighbor(DenseVector
input) {
Review comment:
Could we simplify the logic here and combine some of the methods into
one (`findNeighbor`, `search`, `updateQueue`, `computeDistance` and
`getResult`)?
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java
##########
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseMatrixSerializer;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.OutputStream;
+
+/**
+ * Model data of {@link KnnModel}.
+ *
+ * <p>This class also provides methods to convert model data from Table to
Datastream, and classes
+ * to save/load model data.
+ */
+public class KnnModelData {
+
+ /**
+ * Three element of tuple3 is: 1. Matrix data constructed by sample
vectors. 2. L2 norm of
+ * sample vector. 3. Labels of sample.
+ */
+ public Tuple3<DenseMatrix, DenseVector, DenseVector> modelData;
Review comment:
Can we replace the `Tuple3` here with three class members, e.g.,
`packedFeatures, featureNorms, labels`?
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
##########
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the KNN algorithm.
+ *
+ * <p>See: https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm.
+ */
+public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> {
+
+ protected Map<Param<?>, Object> params = new HashMap<>();
+
+ public Knn() {
+ ParamUtils.initializeMapWithDefaultValues(params, this);
+ }
+
+ @Override
+ public KnnModel fit(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl)
inputs[0]).getTableEnvironment();
+ /* Tuple2 : <sampleVector, label> */
+ DataStream<Tuple2<DenseVector, Double>> inputData =
+ tEnv.toDataStream(inputs[0])
+ .map(
+ new MapFunction<Row, Tuple2<DenseVector,
Double>>() {
+ @Override
+ public Tuple2<DenseVector, Double> map(Row
value) {
+ Double label = (Double)
value.getField(getLabelCol());
+ DenseVector feature =
+ (DenseVector)
value.getField(getFeaturesCol());
+ return Tuple2.of(feature, label);
+ }
+ });
+ DataStream<KnnModelData> modelData = prepareModelData(inputData);
+ KnnModel model = new
KnnModel().setModelData(tEnv.fromDataStream(modelData));
+ ReadWriteUtils.updateExistingParams(model, getParamMap());
+ return model;
+ }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return this.params;
+ }
+
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ }
+
+ public static Knn load(StreamExecutionEnvironment env, String path) throws
IOException {
+ return ReadWriteUtils.loadStageParam(path);
+ }
+
+ /**
+ * Prepares knn model data. Constructs the sample matrix and computes norm
of vectors.
+ *
+ * @param inputData Input vector data with label.
+ * @return Knn model.
+ */
+ private static DataStream<KnnModelData> prepareModelData(
+ DataStream<Tuple2<DenseVector, Double>> inputData) {
+ return DataStreamUtils.mapPartition(
+ inputData,
+ new RichMapPartitionFunction<Tuple2<DenseVector, Double>,
KnnModelData>() {
+ @Override
+ public void mapPartition(
+ Iterable<Tuple2<DenseVector, Double>> values,
+ Collector<KnnModelData> out) {
+ Tuple3<DenseMatrix, DenseVector, DenseVector> model =
prepareData(values);
+ if (model != null) {
+ out.collect(new KnnModelData(model));
+ }
+ }
+ },
+ TypeInformation.of(KnnModelData.class));
+ }
+
+ /**
+ * Prepares knn model data, the output is a Tuple3, which includes matrix,
vector norms and
+ * labels.
+ *
+ * @param trainData Input train data.
+ * @return Model data in format of tuple3.
+ */
+ private static Tuple3<DenseMatrix, DenseVector, DenseVector> prepareData(
+ Iterable<Tuple2<DenseVector, Double>> trainData) {
+ List<Tuple2<DenseVector, Double>> buffer = new ArrayList<>(0);
+ int vecSize = -1;
+ for (Tuple2<DenseVector, Double> tuple2 : trainData) {
+ if (vecSize == -1) {
+ vecSize = tuple2.f0.size();
+ }
+ buffer.add(tuple2);
+ }
+ if (vecSize == -1) {
+ return null;
+ }
+ DenseMatrix matrix = new DenseMatrix(vecSize, buffer.size());
+ DenseVector label = new DenseVector(buffer.size());
+ for (int i = 0; i < buffer.size(); ++i) {
+ Tuple2<DenseVector, Double> tuple2 = buffer.get(i);
+ label.values[i] = tuple2.f1;
+ double[] vectorData = tuple2.f0.toArray();
+ double[] matrixData = matrix.values;
+ System.arraycopy(vectorData, 0, matrixData, i * vecSize, vecSize);
+ }
+ DenseVector norm = computeNorm(matrix);
+ return Tuple3.of(matrix, norm, label);
+ }
+
+ /**
+ * For Euclidean distance, distance = sqrt((a - b)^2) = (sqrt(a^2 + b^2 -
2ab)) So it can
+ * pre-calculate the L2 norm square of the vector, and when calculating
the distance with
+ * another vector, only dot product is calculated.
+ *
+ * @param matrix Input matrix.
+ * @return Norm vector, vector length is the number of samples.
+ */
+ private static DenseVector computeNorm(DenseMatrix matrix) {
+ DenseVector norm = new DenseVector(new double[matrix.numCols()]);
+ double[] normValues = norm.values;
+ double[] matrixData = matrix.values;
+ Arrays.fill(normValues, 0.0);
Review comment:
We can probably remove `Arrays.fill()` here: a newly allocated `double[]`
is already zero-initialized in Java.
##########
File path: flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java
##########
@@ -22,13 +22,84 @@
/** A utility class that provides BLAS routines over matrices and vectors. */
public class BLAS {
+
/** For level-1 function dspmv, use javaBLAS for better performance. */
private static final dev.ludovic.netlib.BLAS JAVA_BLAS =
dev.ludovic.netlib.JavaBLAS.getInstance();
- /** y += a * x . */
+ /**
+ * \sum_i |x_i| .
+ *
+ * @param x x
+ * @return \sum_i |x_i|
+ */
+ public static double asum(DenseVector x) {
+ return JAVA_BLAS.dasum(x.size(), x.values, 0, 1);
+ }
+
+ /**
+ * y += a * x .
+ *
+ * @param a a
+ * @param x x
+ * @param y y
+ */
public static void axpy(double a, DenseVector x, DenseVector y) {
Preconditions.checkArgument(x.size() == y.size(), "Vector size
mismatched.");
JAVA_BLAS.daxpy(x.size(), a, x.values, 1, y.values, 1);
}
+
+ /**
+ * x \cdot y .
+ *
+ * @param x x
+ * @param y y
+ * @return x \cdot y
+ */
+ public static double dot(DenseVector x, DenseVector y) {
+ Preconditions.checkArgument(x.size() == y.size(), "Vector size
mismatched.");
+ return JAVA_BLAS.ddot(x.size(), x.values, 1, y.values, 1);
+ }
+
+ /**
+ * \sqrt(\sum_i x_i * x_i) .
+ *
+ * @param x x
+ * @return \sqrt(\sum_i x_i * x_i)
+ */
+ public static double norm2(DenseVector x) {
+ return JAVA_BLAS.dnrm2(x.size(), x.values, 1);
+ }
+
+ /**
+ * x = x * a .
+ *
+ * @param a a
+ * @param x x
+ */
+ public static void scal(double a, DenseVector x) {
+ JAVA_BLAS.dscal(x.size(), a, x.values, 1);
+ }
+
+ /**
+ * y := alpha * A * x + beta * y.
+ *
+ * @param matA m x n matrix A.
+ * @param transA transform matrix or not.
+ * @param x dense vector with size n.
+ * @param y dense vector with size m.
+ */
+ public static void gemv(
+ double alpha,
+ DenseMatrix matA,
+ boolean transA,
+ DenseVector x,
+ double beta,
+ DenseVector y) {
+ final int m = matA.numRows();
Review comment:
Can you add a dimension check here, similar to what we did in `dot`?
Also, if a variable is used only once and does not carry much
information, can you just remove it? e.g., `m`, `n`, `lda`.
##########
File path:
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##########
@@ -0,0 +1,251 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+
+/** Tests {@link Knn} and {@link KnnModel}. */
+public class KnnTest {
+ @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+ private StreamExecutionEnvironment env;
+ private StreamTableEnvironment tEnv;
+ private Table trainData;
+ private static final List<Row> trainArray =
+ new ArrayList<>(
+ Arrays.asList(
+ Row.of(Vectors.dense(2.0, 3.0), 1.0),
+ Row.of(Vectors.dense(2.1, 3.1), 1.0),
+ Row.of(Vectors.dense(200.1, 300.1), 2.0),
+ Row.of(Vectors.dense(200.2, 300.2), 2.0),
+ Row.of(Vectors.dense(200.3, 300.3), 2.0),
+ Row.of(Vectors.dense(200.4, 300.4), 2.0),
+ Row.of(Vectors.dense(200.4, 300.4), 2.0),
+ Row.of(Vectors.dense(200.6, 300.6), 2.0),
+ Row.of(Vectors.dense(2.1, 3.1), 1.0),
+ Row.of(Vectors.dense(2.1, 3.1), 1.0),
+ Row.of(Vectors.dense(2.1, 3.1), 1.0),
+ Row.of(Vectors.dense(2.1, 3.1), 1.0),
+ Row.of(Vectors.dense(2.3, 3.2), 1.0),
+ Row.of(Vectors.dense(2.3, 3.2), 1.0),
+ Row.of(Vectors.dense(2.8, 3.2), 3.0),
+ Row.of(Vectors.dense(300., 3.2), 4.0),
+ Row.of(Vectors.dense(2.2, 3.2), 1.0),
+ Row.of(Vectors.dense(2.4, 3.2), 5.0),
+ Row.of(Vectors.dense(2.5, 3.2), 5.0),
+ Row.of(Vectors.dense(2.5, 3.2), 5.0),
+ Row.of(Vectors.dense(2.1, 3.1), 1.0)));
+
+ private static final List<Row> testArray =
+ new ArrayList<>(
+ Arrays.asList(
+ Row.of(Vectors.dense(4.0, 4.1), 5.0),
+ Row.of(Vectors.dense(300, 42), 2.0)));
+ private Table testData;
+
+ @Before
+ public void before() {
+ Configuration config = new Configuration();
+
config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH,
true);
+ env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+ env.setParallelism(4);
+ env.enableCheckpointing(100);
+ env.setRestartStrategy(RestartStrategies.noRestart());
+ tEnv = StreamTableEnvironment.create(env);
+
+ Schema schema =
+ Schema.newBuilder()
+ .column("f0", DataTypes.of(DenseVector.class))
+ .column("f1", DataTypes.DOUBLE())
+ .build();
+
+ DataStream<Row> dataStream = env.fromCollection(trainArray);
+ trainData = tEnv.fromDataStream(dataStream, schema).as("features",
"label");
+
+ DataStream<Row> predDataStream = env.fromCollection(testArray);
+ testData = tEnv.fromDataStream(predDataStream, schema).as("features",
"label");
+ }
+
+ private static List<Tuple2<Double, Double>> executeAndCollect(
+ Table output, String labelCol, String predictionCol) throws
Exception {
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl)
output).getTableEnvironment();
+
+ DataStream<Tuple2<Double, Double>> stream =
+ tEnv.toDataStream(output)
+ .map(
+ new MapFunction<Row, Tuple2<Double, Double>>()
{
+ @Override
+ public Tuple2<Double, Double> map(Row row)
{
+ return Tuple2.of(
+ (Double)
row.getField(labelCol),
+ (Double)
row.getField(predictionCol));
+ }
+ });
+ return IteratorUtils.toList(stream.executeAndCollect());
+ }
+
+ private static void verifyClusteringResult(List<Tuple2<Double, Double>>
result) {
+ for (Tuple2<Double, Double> t2 : result) {
+ Assert.assertEquals(t2.f0, t2.f1);
+ }
+ }
+
+ @Test
+ public void testParam() {
+ Knn knn = new Knn();
+ assertEquals("features", knn.getFeaturesCol());
+ assertEquals("label", knn.getLabelCol());
+ assertEquals(10L, knn.getK().longValue());
+ assertEquals("prediction", knn.getPredictionCol());
+
+ knn.setLabelCol("test_label")
+ .setFeaturesCol("test_features")
+ .setK(4)
+ .setPredictionCol("test_prediction");
+
+ assertEquals("test_features", knn.getFeaturesCol());
+ assertEquals("test_label", knn.getLabelCol());
+ assertEquals(4L, knn.getK().longValue());
+ assertEquals("test_prediction", knn.getPredictionCol());
+ }
+
+ @Test
+ public void testFeaturePredictionParam() throws Exception {
+ Knn knn =
+ new Knn()
+ .setLabelCol("test_label")
+ .setFeaturesCol("test_features")
+ .setK(4)
+ .setPredictionCol("test_prediction");
+ KnnModel model = knn.fit(trainData.as("test_features, test_label"));
+ Table output = model.transform(testData.as("test_features,
test_label"))[0];
+
+ assertEquals(
+ Arrays.asList("test_features", "test_label",
"test_prediction"),
+ output.getResolvedSchema().getColumnNames());
+
+ List<Tuple2<Double, Double>> result =
Review comment:
We probably do not need to execute the job here, since we are testing
`params`.
##########
File path:
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##########
@@ -0,0 +1,251 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+
+/** Tests {@link Knn} and {@link KnnModel}. */
+public class KnnTest {
+ @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+ private StreamExecutionEnvironment env;
+ private StreamTableEnvironment tEnv;
+ private Table trainData;
+ private static final List<Row> trainArray =
+ new ArrayList<>(
+ Arrays.asList(
+ Row.of(Vectors.dense(2.0, 3.0), 1.0),
+ Row.of(Vectors.dense(2.1, 3.1), 1.0),
+ Row.of(Vectors.dense(200.1, 300.1), 2.0),
+ Row.of(Vectors.dense(200.2, 300.2), 2.0),
+ Row.of(Vectors.dense(200.3, 300.3), 2.0),
+ Row.of(Vectors.dense(200.4, 300.4), 2.0),
+ Row.of(Vectors.dense(200.4, 300.4), 2.0),
+ Row.of(Vectors.dense(200.6, 300.6), 2.0),
+ Row.of(Vectors.dense(2.1, 3.1), 1.0),
+ Row.of(Vectors.dense(2.1, 3.1), 1.0),
+ Row.of(Vectors.dense(2.1, 3.1), 1.0),
+ Row.of(Vectors.dense(2.1, 3.1), 1.0),
+ Row.of(Vectors.dense(2.3, 3.2), 1.0),
+ Row.of(Vectors.dense(2.3, 3.2), 1.0),
+ Row.of(Vectors.dense(2.8, 3.2), 3.0),
+ Row.of(Vectors.dense(300., 3.2), 4.0),
+ Row.of(Vectors.dense(2.2, 3.2), 1.0),
+ Row.of(Vectors.dense(2.4, 3.2), 5.0),
+ Row.of(Vectors.dense(2.5, 3.2), 5.0),
+ Row.of(Vectors.dense(2.5, 3.2), 5.0),
+ Row.of(Vectors.dense(2.1, 3.1), 1.0)));
+
+ private static final List<Row> testArray =
+ new ArrayList<>(
+ Arrays.asList(
+ Row.of(Vectors.dense(4.0, 4.1), 5.0),
+ Row.of(Vectors.dense(300, 42), 2.0)));
+ private Table testData;
+
+ @Before
+ public void before() {
+ Configuration config = new Configuration();
+
config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH,
true);
+ env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+ env.setParallelism(4);
+ env.enableCheckpointing(100);
+ env.setRestartStrategy(RestartStrategies.noRestart());
+ tEnv = StreamTableEnvironment.create(env);
+
+ Schema schema =
+ Schema.newBuilder()
+ .column("f0", DataTypes.of(DenseVector.class))
+ .column("f1", DataTypes.DOUBLE())
+ .build();
+
+ DataStream<Row> dataStream = env.fromCollection(trainArray);
+ trainData = tEnv.fromDataStream(dataStream, schema).as("features",
"label");
+
+ DataStream<Row> predDataStream = env.fromCollection(testArray);
+ testData = tEnv.fromDataStream(predDataStream, schema).as("features",
"label");
+ }
+
+ private static List<Tuple2<Double, Double>> executeAndCollect(
+ Table output, String labelCol, String predictionCol) throws
Exception {
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl)
output).getTableEnvironment();
+
+ DataStream<Tuple2<Double, Double>> stream =
+ tEnv.toDataStream(output)
+ .map(
+ new MapFunction<Row, Tuple2<Double, Double>>()
{
+ @Override
+ public Tuple2<Double, Double> map(Row row)
{
+ return Tuple2.of(
+ (Double)
row.getField(labelCol),
+ (Double)
row.getField(predictionCol));
+ }
+ });
+ return IteratorUtils.toList(stream.executeAndCollect());
+ }
+
+ private static void verifyClusteringResult(List<Tuple2<Double, Double>>
result) {
+ for (Tuple2<Double, Double> t2 : result) {
+ Assert.assertEquals(t2.f0, t2.f1);
+ }
+ }
+
+ @Test
+ public void testParam() {
+ Knn knn = new Knn();
+ assertEquals("features", knn.getFeaturesCol());
+ assertEquals("label", knn.getLabelCol());
+ assertEquals(10L, knn.getK().longValue());
Review comment:
Can you remove the `.longValue()` here and simply use `assertEquals(10,
knn.getK());`?
##########
File path:
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##########
@@ -0,0 +1,251 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+
+/** Tests {@link Knn} and {@link KnnModel}. */
+public class KnnTest {
+ @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+ private StreamExecutionEnvironment env;
+ private StreamTableEnvironment tEnv;
+ private Table trainData;
+ private static final List<Row> trainArray =
+ new ArrayList<>(
+ Arrays.asList(
+ Row.of(Vectors.dense(2.0, 3.0), 1.0),
+ Row.of(Vectors.dense(2.1, 3.1), 1.0),
+ Row.of(Vectors.dense(200.1, 300.1), 2.0),
+ Row.of(Vectors.dense(200.2, 300.2), 2.0),
+ Row.of(Vectors.dense(200.3, 300.3), 2.0),
+ Row.of(Vectors.dense(200.4, 300.4), 2.0),
+ Row.of(Vectors.dense(200.4, 300.4), 2.0),
+ Row.of(Vectors.dense(200.6, 300.6), 2.0),
+ Row.of(Vectors.dense(2.1, 3.1), 1.0),
+ Row.of(Vectors.dense(2.1, 3.1), 1.0),
+ Row.of(Vectors.dense(2.1, 3.1), 1.0),
+ Row.of(Vectors.dense(2.1, 3.1), 1.0),
+ Row.of(Vectors.dense(2.3, 3.2), 1.0),
+ Row.of(Vectors.dense(2.3, 3.2), 1.0),
+ Row.of(Vectors.dense(2.8, 3.2), 3.0),
+ Row.of(Vectors.dense(300., 3.2), 4.0),
+ Row.of(Vectors.dense(2.2, 3.2), 1.0),
+ Row.of(Vectors.dense(2.4, 3.2), 5.0),
+ Row.of(Vectors.dense(2.5, 3.2), 5.0),
+ Row.of(Vectors.dense(2.5, 3.2), 5.0),
+ Row.of(Vectors.dense(2.1, 3.1), 1.0)));
+
+ private static final List<Row> testArray =
+ new ArrayList<>(
+ Arrays.asList(
+ Row.of(Vectors.dense(4.0, 4.1), 5.0),
+ Row.of(Vectors.dense(300, 42), 2.0)));
+ private Table testData;
+
+ @Before
+ public void before() {
+ Configuration config = new Configuration();
+
config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH,
true);
+ env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+ env.setParallelism(4);
+ env.enableCheckpointing(100);
+ env.setRestartStrategy(RestartStrategies.noRestart());
+ tEnv = StreamTableEnvironment.create(env);
+
+ Schema schema =
+ Schema.newBuilder()
+ .column("f0", DataTypes.of(DenseVector.class))
+ .column("f1", DataTypes.DOUBLE())
+ .build();
+
+ DataStream<Row> dataStream = env.fromCollection(trainArray);
+ trainData = tEnv.fromDataStream(dataStream, schema).as("features",
"label");
+
+ DataStream<Row> predDataStream = env.fromCollection(testArray);
+ testData = tEnv.fromDataStream(predDataStream, schema).as("features",
"label");
+ }
+
+ private static List<Tuple2<Double, Double>> executeAndCollect(
+ Table output, String labelCol, String predictionCol) throws
Exception {
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl)
output).getTableEnvironment();
+
+ DataStream<Tuple2<Double, Double>> stream =
+ tEnv.toDataStream(output)
+ .map(
+ new MapFunction<Row, Tuple2<Double, Double>>()
{
+ @Override
+ public Tuple2<Double, Double> map(Row row)
{
+ return Tuple2.of(
+ (Double)
row.getField(labelCol),
+ (Double)
row.getField(predictionCol));
+ }
+ });
+ return IteratorUtils.toList(stream.executeAndCollect());
+ }
+
+ private static void verifyClusteringResult(List<Tuple2<Double, Double>>
result) {
+ for (Tuple2<Double, Double> t2 : result) {
+ Assert.assertEquals(t2.f0, t2.f1);
+ }
+ }
+
+ @Test
+ public void testParam() {
+ Knn knn = new Knn();
+ assertEquals("features", knn.getFeaturesCol());
+ assertEquals("label", knn.getLabelCol());
+ assertEquals(10L, knn.getK().longValue());
+ assertEquals("prediction", knn.getPredictionCol());
+
+ knn.setLabelCol("test_label")
+ .setFeaturesCol("test_features")
+ .setK(4)
+ .setPredictionCol("test_prediction");
+
+ assertEquals("test_features", knn.getFeaturesCol());
+ assertEquals("test_label", knn.getLabelCol());
+ assertEquals(4L, knn.getK().longValue());
+ assertEquals("test_prediction", knn.getPredictionCol());
+ }
+
+ @Test
+ public void testFeaturePredictionParam() throws Exception {
+ Knn knn =
+ new Knn()
+ .setLabelCol("test_label")
+ .setFeaturesCol("test_features")
+ .setK(4)
+ .setPredictionCol("test_prediction");
+ KnnModel model = knn.fit(trainData.as("test_features, test_label"));
+ Table output = model.transform(testData.as("test_features,
test_label"))[0];
+
+ assertEquals(
+ Arrays.asList("test_features", "test_label",
"test_prediction"),
+ output.getResolvedSchema().getColumnNames());
+
+ List<Tuple2<Double, Double>> result =
+ executeAndCollect(output, "test_label", "test_prediction");
+ verifyClusteringResult(result);
+ }
+
+ @Test
+ public void testFewerDistinctPointsThanCluster() throws Exception {
+ Knn knn = new Knn().setK(4);
Review comment:
Can we just use the default `K` here?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]