lindong28 commented on a change in pull request #24: URL: https://github.com/apache/flink-ml/pull/24#discussion_r765407180
########## File path: flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java ########## @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.datastream; + +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.TypeExtractor; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.TimestampedCollector; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; + +/** Provides utility functions for {@link DataStream}. */ +public class DataStreamUtils { + /** + * Applies allReduceSum on the input data stream. The input data stream is supposed to contain + * one double array in each partition. The result data stream has the same parallelism as the + * input, where each partition contains one double array that sums all of the double arrays in + * the input data stream. + * + * <p>Note that we throw exception when one of the following two cases happen: + * <li>There exists one partition that contains more than one double array. + * <li>The length of the double array is not consistent among all partitions. + * + * @param input The input data stream. + * @return The result data stream. + */ + public static DataStream<double[]> allReduceSum(DataStream<double[]> input) { + return AllReduceImpl.allReduceSum(input); + } + + /** + * Applies a {@link MapPartitionFunction} on a bounded data stream. + * + * @param input The input data stream. + * @param func The user defined mapPartition function. + * @param <IN> The class type of the input element. + * @param <OUT> The class type of output element. + * @return The result data stream. 
+ */ + public static <IN, OUT> DataStream<OUT> mapPartition( + DataStream<IN> input, MapPartitionFunction<IN, OUT> func) { + TypeInformation<OUT> resultType = + TypeExtractor.getMapPartitionReturnTypes(func, input.getType(), null, true); + return input.transform("mapPartition", resultType, new MapPartitionOperator<>(func)) + .setParallelism(input.getParallelism()); + } + + public static <IN, OUT> DataStream<OUT> mapPartition( Review comment: Can we reuse the existing `public static <IN, OUT> DataStream<OUT> mapPartition(DataStream<IN> input, MapPartitionFunction<IN, OUT> func)` instead of creating this method? ########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java ########## @@ -0,0 +1,273 @@ +package org.apache.flink.ml.classification.knn; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.linalg.DenseMatrix; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.nio.file.Files; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; + +import static org.junit.Assert.assertEquals; + +/** Tests Knn and KnnModel. 
*/ +public class KnnTest { + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + private Table trainData; + private static final String LABEL_COL = "test_label"; + private static final String PRED_COL = "test_prediction"; + private static final String VEC_COL = "test_features"; + private static final List<Row> trainArray = + new ArrayList<>( + Arrays.asList( + Row.of(1, Vectors.dense(2.0, 3.0)), + Row.of(1, Vectors.dense(2.1, 3.1)), + Row.of(2, Vectors.dense(200.1, 300.1)), + Row.of(2, Vectors.dense(200.2, 300.2)), + Row.of(2, Vectors.dense(200.3, 300.3)), + Row.of(2, Vectors.dense(200.4, 300.4)), + Row.of(2, Vectors.dense(200.4, 300.4)), + Row.of(2, Vectors.dense(200.6, 300.6)), + Row.of(1, Vectors.dense(2.1, 3.1)), + Row.of(1, Vectors.dense(2.1, 3.1)), + Row.of(1, Vectors.dense(2.1, 3.1)), + Row.of(1, Vectors.dense(2.1, 3.1)), + Row.of(1, Vectors.dense(2.3, 3.2)), + Row.of(1, Vectors.dense(2.3, 3.2)), + Row.of(3, Vectors.dense(2.8, 3.2)), + Row.of(4, Vectors.dense(300., 3.2)), + Row.of(1, Vectors.dense(2.2, 3.2)), + Row.of(5, Vectors.dense(2.4, 3.2)), + Row.of(5, Vectors.dense(2.5, 3.2)), + Row.of(5, Vectors.dense(2.5, 3.2)), + Row.of(1, Vectors.dense(2.1, 3.1)))); + + private static final List<Row> testArray = + new ArrayList<>( + Arrays.asList( + Row.of(5, Vectors.dense(4.0, 4.1)), Row.of(2, Vectors.dense(300, 42)))); + private Table testData; + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(4); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + + Schema schema = + Schema.newBuilder() + .column("f0", DataTypes.INT()) + .column("f1", DataTypes.of(DenseVector.class)) + .build(); + + DataStream<Row> dataStream = env.fromCollection(trainArray); + trainData = tEnv.fromDataStream(dataStream, schema).as(LABEL_COL + "," + VEC_COL); Review comment: While this approach works, it seems less efficient to require the underlying implementation to parse the schema from the string. Could we use `as(LABEL_COL, VEC_COL)` here? ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java ########## @@ -0,0 +1,378 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.classification.knn; + +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.api.common.functions.AbstractRichFunction; +import org.apache.flink.api.connector.source.Source; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.connector.file.sink.FileSink; +import org.apache.flink.connector.file.src.FileSource; +import org.apache.flink.core.fs.Path; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseMatrix; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner; +import org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy; +import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.table.catalog.ResolvedSchema; +import org.apache.flink.table.types.DataType; +import org.apache.flink.types.Row; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.PriorityQueue; + +/** Knn model fitted by estimator. */ +public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> { + protected Map<Param<?>, Object> params = new HashMap<>(); + private Table[] modelData; + + /** Constructor. */ + public KnnModel() { + ParamUtils.initializeMapWithDefaultValues(params, this); + } + + /** + * Sets model data for knn prediction. + * + * @param modelData Knn model data. + * @return Knn model. + */ + @Override + public KnnModel setModelData(Table... modelData) { + this.modelData = modelData; + return this; + } + + /** + * Gets model data. + * + * @return Table array including model data tables. + */ + @Override + public Table[] getModelData() { + return modelData; + } + + /** + * Predicts label with knn model. + * + * @param inputs List of tables. + * @return Prediction result. + */ + @Override + @SuppressWarnings("unchecked") + public Table[] transform(Table... 
inputs) { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<Row> input = tEnv.toDataStream(inputs[0]); + DataStream<Row> model = tEnv.toDataStream(modelData[0]); + final String broadcastKey = "broadcastModelKey"; + String resultCols = getPredictionCol(); + DataType resultTypes = DataTypes.INT(); + ResolvedSchema outputSchema = + TableUtils.getOutputSchema(inputs[0].getResolvedSchema(), resultCols, resultTypes); Review comment: Instead of calling `TableUtils.getOutputSchema(...)` followed by `TableUtils.getRowTypeInfo(...)`, could we construct the `outputTypeInfo` directly using the same approach as `LogisticRegressionModel::transform`? This could help make Flink ML code more consistent and also reduce the complexity of `TableUtils`. ########## File path: flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseMatrixSerializer.java ########## @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.flink.ml.linalg.typeinfo; + +import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.ml.linalg.DenseMatrix; + +import java.io.IOException; +import java.util.Arrays; + +/** Specialized serializer for {@link DenseMatrix}. */ +public final class DenseMatrixSerializer extends TypeSerializerSingleton<DenseMatrix> { + + private static final long serialVersionUID = 1L; + + private static final double[] EMPTY = new double[0]; + + private static final DenseMatrixSerializer INSTANCE = new DenseMatrixSerializer(); + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public DenseMatrix createInstance() { + return new DenseMatrix(0, 0, EMPTY); + } + + @Override + public DenseMatrix copy(DenseMatrix from) { + return new DenseMatrix( + from.numRows, from.numCols, Arrays.copyOf(from.values, from.values.length)); + } + + @Override + public DenseMatrix copy(DenseMatrix from, DenseMatrix reuse) { + if (from.values.length == reuse.values.length) { Review comment: What if the `reuse.numRows != from.numRows`? In this case, we can still reuse `reuse.values`, but we probably need to create a new `DenseMatrix` with the correct `numRows`. ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java ########## @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.knn; + +import org.apache.flink.api.common.serialization.Encoder; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.connector.file.src.reader.SimpleStreamFormat; +import org.apache.flink.core.fs.FSDataInputStream; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.ml.linalg.DenseMatrix; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.typeinfo.DenseMatrixSerializer; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.types.Row; + +import java.io.EOFException; +import java.io.IOException; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Objects; + +/** Knn model data, which stores the data used to calculate the distances between nodes. */ +public class KnnModelData { + private final List<Tuple3<DenseMatrix, DenseVector, int[]>> dictData; + private final Comparator<? super Tuple2<Double, Integer>> comparator; Review comment: Could this `comparator` be `public static final`? It is not clear why this needs to be a non-static member variable and be initialized in the constructor. ########## File path: flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/DenseMatrixSerializer.java ########## @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.flink.ml.linalg.typeinfo; + +import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.ml.linalg.DenseMatrix; + +import java.io.IOException; +import java.util.Arrays; + +/** Specialized serializer for {@link DenseMatrix}. */ +public final class DenseMatrixSerializer extends TypeSerializerSingleton<DenseMatrix> { + + private static final long serialVersionUID = 1L; + + private static final double[] EMPTY = new double[0]; + + private static final DenseMatrixSerializer INSTANCE = new DenseMatrixSerializer(); + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public DenseMatrix createInstance() { + return new DenseMatrix(0, 0, EMPTY); + } + + @Override + public DenseMatrix copy(DenseMatrix from) { + return new DenseMatrix( + from.numRows, from.numCols, Arrays.copyOf(from.values, from.values.length)); + } + + @Override + public DenseMatrix copy(DenseMatrix from, DenseMatrix reuse) { + if (from.values.length == reuse.values.length) { + System.arraycopy(from.values, 0, reuse.values, 0, from.values.length); + return reuse; + } + return copy(from); + } + + @Override + public int getLength() { + return -1; + } + + @Override + public void serialize(DenseMatrix matrix, DataOutputView target) throws IOException { + if (matrix == null) { + throw new IllegalArgumentException("The matrix must not be null."); + } + final int len = matrix.values.length; + target.writeInt(matrix.numRows); + target.writeInt(matrix.numCols); + for (int i = 0; i < len; i++) { + target.writeDouble(matrix.values[i]); + } + } + + @Override + public DenseMatrix deserialize(DataInputView source) throws IOException { + int m = source.readInt(); + int n = source.readInt(); + double[] values = new double[m * n]; + deserializeDoubleArray(values, source, m * n); + return new DenseMatrix(m, n, values); + } + + private static void deserializeDoubleArray(double[] dst, DataInputView source, int len) + throws IOException { + for (int i = 0; i < len; i++) { + dst[i] = source.readDouble(); + } + } + + @Override + public DenseMatrix deserialize(DenseMatrix reuse, DataInputView source) throws IOException { + int m = source.readInt(); + int n = source.readInt(); + double[] values = reuse.values; Review comment: Is it still correct to do so if `reuse.values.length != m * n`? ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java ########## @@ -0,0 +1,378 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.classification.knn; + +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.api.common.functions.AbstractRichFunction; +import org.apache.flink.api.connector.source.Source; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.connector.file.sink.FileSink; +import org.apache.flink.connector.file.src.FileSource; +import org.apache.flink.core.fs.Path; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseMatrix; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner; +import org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy; +import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.table.catalog.ResolvedSchema; +import org.apache.flink.table.types.DataType; +import org.apache.flink.types.Row; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.PriorityQueue; + +/** Knn model fitted by estimator. */ +public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> { + protected Map<Param<?>, Object> params = new HashMap<>(); + private Table[] modelData; + + /** Constructor. */ + public KnnModel() { + ParamUtils.initializeMapWithDefaultValues(params, this); + } + + /** + * Sets model data for knn prediction. + * + * @param modelData Knn model data. + * @return Knn model. + */ + @Override + public KnnModel setModelData(Table... modelData) { + this.modelData = modelData; + return this; + } + + /** + * Gets model data. + * + * @return Table array including model data tables. + */ + @Override + public Table[] getModelData() { + return modelData; + } + + /** + * Predicts label with knn model. + * + * @param inputs List of tables. + * @return Prediction result. + */ + @Override + @SuppressWarnings("unchecked") + public Table[] transform(Table... 
inputs) { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<Row> input = tEnv.toDataStream(inputs[0]); + DataStream<Row> model = tEnv.toDataStream(modelData[0]); + final String broadcastKey = "broadcastModelKey"; + String resultCols = getPredictionCol(); + DataType resultTypes = DataTypes.INT(); + ResolvedSchema outputSchema = + TableUtils.getOutputSchema(inputs[0].getResolvedSchema(), resultCols, resultTypes); + + DataStream<Row> output = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(input), + Collections.singletonMap(broadcastKey, model), + inputList -> { + DataStream inoutData = inputList.get(0); + return inoutData.transform( + "mapFunc", + TableUtils.getRowTypeInfo(outputSchema), + new PredictOperator(broadcastKey, getK(), getFeaturesCol())); + }); + + return new Table[] {tEnv.fromDataStream(output)}; + } + + /** @return Parameters for algorithm. */ + @Override + public Map<Param<?>, Object> getParamMap() { + return this.params; + } + + /** + * Saves model data. + * + * @param path Path to save. + */ + @Override + public void save(String path) throws IOException { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) modelData[0]).getTableEnvironment(); + + String dataPath = ReadWriteUtils.getDataPath(path); + FileSink<Row> sink = + FileSink.forRowFormat(new Path(dataPath), new KnnModelData.ModelDataEncoder()) + .withRollingPolicy(OnCheckpointRollingPolicy.build()) + .withBucketAssigner(new BasePathBucketAssigner<>()) + .build(); + tEnv.toDataStream(modelData[0]).sinkTo(sink); + ReadWriteUtils.saveMetadata(this, path); + } + + /** + * Loads model data from path. + * + * @param env Stream execution environment. + * @param path Model path. + * @return Knn model. + */ + public static KnnModel load(StreamExecutionEnvironment env, String path) throws IOException { + StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); + KnnModel retModel = ReadWriteUtils.loadStageParam(path); + + Source<Row, ?, ?> source = + FileSource.forRecordStreamFormat( + new KnnModelData.ModelDataStreamFormat(), + ReadWriteUtils.getDataPaths(path)) + .build(); + DataStream<Row> modelDataStream = + env.fromSource(source, WatermarkStrategy.noWatermarks(), "data"); + retModel.modelData = + new Table[] {tEnv.fromDataStream(modelDataStream, KnnModelData.getModelSchema())}; + return retModel; + } + + /** This operator loads model data and predicts result. */ + private static class PredictOperator + extends AbstractUdfStreamOperator<Row, AbstractRichFunction> + implements OneInputStreamOperator<Row, Row> { + + private boolean firstEle = true; + private final String featureCol; + private transient KnnModelData modelData; + private final Integer topN; Review comment: nits: we use `int` instead of `Integer` when possible. Could we use `int` here? ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java ########## @@ -0,0 +1,378 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
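To make the nit concrete, only the declaration needs to change (a fragment for illustration; the rest of `PredictOperator` stays as in the PR):

```java
// Before (from the PR):
private final Integer topN;

// After (suggested): the primitive type avoids repeated unboxing wherever
// topN is read per record, and makes the non-null expectation explicit.
private final int topN;
```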
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.knn; + +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.api.common.functions.AbstractRichFunction; +import org.apache.flink.api.connector.source.Source; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.connector.file.sink.FileSink; +import org.apache.flink.connector.file.src.FileSource; +import org.apache.flink.core.fs.Path; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseMatrix; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner; +import org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy; +import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.table.catalog.ResolvedSchema; +import org.apache.flink.table.types.DataType; +import org.apache.flink.types.Row; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.PriorityQueue; + +/** Knn model fitted by estimator. */ +public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> { + protected Map<Param<?>, Object> params = new HashMap<>(); + private Table[] modelData; + + /** Constructor. */ + public KnnModel() { + ParamUtils.initializeMapWithDefaultValues(params, this); + } + + /** + * Sets model data for knn prediction. + * + * @param modelData Knn model data. + * @return Knn model. + */ + @Override + public KnnModel setModelData(Table... modelData) { + this.modelData = modelData; + return this; + } + + /** + * Gets model data. + * + * @return Table array including model data tables. + */ + @Override + public Table[] getModelData() { + return modelData; + } + + /** + * Predicts label with knn model. + * + * @param inputs List of tables. + * @return Prediction result. + */ + @Override + @SuppressWarnings("unchecked") + public Table[] transform(Table... 
inputs) { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<Row> input = tEnv.toDataStream(inputs[0]); + DataStream<Row> model = tEnv.toDataStream(modelData[0]); + final String broadcastKey = "broadcastModelKey"; + String resultCols = getPredictionCol(); + DataType resultTypes = DataTypes.INT(); + ResolvedSchema outputSchema = + TableUtils.getOutputSchema(inputs[0].getResolvedSchema(), resultCols, resultTypes); + + DataStream<Row> output = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(input), + Collections.singletonMap(broadcastKey, model), + inputList -> { + DataStream inoutData = inputList.get(0); + return inoutData.transform( + "mapFunc", + TableUtils.getRowTypeInfo(outputSchema), + new PredictOperator(broadcastKey, getK(), getFeaturesCol())); + }); + + return new Table[] {tEnv.fromDataStream(output)}; + } + + /** @return Parameters for algorithm. */ + @Override + public Map<Param<?>, Object> getParamMap() { + return this.params; + } + + /** + * Saves model data. + * + * @param path Path to save. + */ + @Override + public void save(String path) throws IOException { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) modelData[0]).getTableEnvironment(); + + String dataPath = ReadWriteUtils.getDataPath(path); + FileSink<Row> sink = + FileSink.forRowFormat(new Path(dataPath), new KnnModelData.ModelDataEncoder()) + .withRollingPolicy(OnCheckpointRollingPolicy.build()) + .withBucketAssigner(new BasePathBucketAssigner<>()) + .build(); + tEnv.toDataStream(modelData[0]).sinkTo(sink); + ReadWriteUtils.saveMetadata(this, path); + } + + /** + * Loads model data from path. + * + * @param env Stream execution environment. + * @param path Model path. + * @return Knn model. + */ + public static KnnModel load(StreamExecutionEnvironment env, String path) throws IOException { + StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); + KnnModel retModel = ReadWriteUtils.loadStageParam(path); + + Source<Row, ?, ?> source = + FileSource.forRecordStreamFormat( + new KnnModelData.ModelDataStreamFormat(), + ReadWriteUtils.getDataPaths(path)) + .build(); + DataStream<Row> modelDataStream = + env.fromSource(source, WatermarkStrategy.noWatermarks(), "data"); + retModel.modelData = + new Table[] {tEnv.fromDataStream(modelDataStream, KnnModelData.getModelSchema())}; + return retModel; + } + + /** This operator loads model data and predicts result. */ + private static class PredictOperator + extends AbstractUdfStreamOperator<Row, AbstractRichFunction> + implements OneInputStreamOperator<Row, Row> { + + private boolean firstEle = true; Review comment: nits: it does not seem common to use `Ele` as an abbreviation of `Element`. Could we use `firstElement` here? ########## File path: flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java ########## @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.linalg; + +import org.apache.flink.util.Preconditions; + +/** A utility class that provides BLAS routines over matrices and vectors. */ +public class BLAS { + + private static final dev.ludovic.netlib.BLAS NATIVE_BLAS = + dev.ludovic.netlib.BLAS.getInstance(); + + /** + * \sum_i |x_i| . + * + * @param x x + * @return \sum_i |x_i| + */ + public static double asum(DenseVector x) { + return NATIVE_BLAS.dasum(x.size(), x.values, 0, 1); + } + + /** + * y += a * x . + * + * @param a a + * @param x x + * @param y y + */ + public static void axpy(double a, DenseVector x, DenseVector y) { + Preconditions.checkArgument(x.size() == y.size(), "Array dimension mismatched."); + NATIVE_BLAS.daxpy(x.size(), a, x.values, 1, y.values, 1); + } + + /** + * x \cdot y . + * + * @param x x + * @param y y + * @return x \cdot y + */ + public static double dot(DenseVector x, DenseVector y) { + Preconditions.checkArgument(x.size() == y.size(), "Array dimension mismatched."); + return NATIVE_BLAS.ddot(x.size(), x.values, 1, y.values, 1); + } + + /** + * \sqrt(\sum_i x_i * x_i) . + * + * @param x x + * @return \sqrt(\sum_i x_i * x_i) + */ + public static double norm2(DenseVector x) { + return NATIVE_BLAS.dnrm2(x.size(), x.values, 1); + } + + /** + * x = x * a . + * + * @param a a + * @param x x + */ + public static void scal(double a, DenseVector x) { + NATIVE_BLAS.dscal(x.size(), a, x.values, 1); + } + + /** y := alpha * A * x + beta * y . */ Review comment: Could we make the Java doc consistent with the variable names, e.g. replace `matA` with `A`. Unlike other methods in BLAS, this one is a bit more complex to understand (e.g. it uses `transA`). Could we add Java doc similar to the `BLAS::gemv(...)` in Spark? 
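For reference, Spark-style Javadoc for this method might look roughly like the sketch below (the parameter names and order are assumed from the review context; adjust to the actual signature in this PR):

```java
/**
 * y := alpha * A * x + beta * y, or y := alpha * A^T * x + beta * y when transA is true.
 *
 * @param alpha a scalar that scales the matrix-vector product.
 * @param matA the m x n matrix A.
 * @param transA whether to multiply by A (false) or by its transpose A^T (true).
 * @param x the vector to be multiplied; of size n (or m when transA is true).
 * @param beta a scalar that scales the initial value of y.
 * @param y the result vector, updated in place; of size m (or n when transA is true).
 */
public static void gemv(
        double alpha,
        DenseMatrix matA,
        boolean transA,
        DenseVector x,
        double beta,
        DenseVector y) {
    // implementation as in the PR
}
```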
########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java ########## @@ -0,0 +1,273 @@ +package org.apache.flink.ml.classification.knn; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.linalg.DenseMatrix; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.nio.file.Files; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; + +import static org.junit.Assert.assertEquals; + +/** Tests Knn and KnnModel. */ +public class KnnTest { + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + private Table trainData; + private static final String LABEL_COL = "test_label"; + private static final String PRED_COL = "test_prediction"; + private static final String VEC_COL = "test_features"; + private static final List<Row> trainArray = + new ArrayList<>( + Arrays.asList( + Row.of(1, Vectors.dense(2.0, 3.0)), + Row.of(1, Vectors.dense(2.1, 3.1)), + Row.of(2, Vectors.dense(200.1, 300.1)), + Row.of(2, Vectors.dense(200.2, 300.2)), + Row.of(2, Vectors.dense(200.3, 300.3)), + Row.of(2, Vectors.dense(200.4, 300.4)), + Row.of(2, Vectors.dense(200.4, 300.4)), + Row.of(2, Vectors.dense(200.6, 300.6)), + Row.of(1, Vectors.dense(2.1, 3.1)), + Row.of(1, Vectors.dense(2.1, 3.1)), + Row.of(1, Vectors.dense(2.1, 3.1)), + Row.of(1, Vectors.dense(2.1, 3.1)), + Row.of(1, Vectors.dense(2.3, 3.2)), + Row.of(1, Vectors.dense(2.3, 3.2)), + Row.of(3, Vectors.dense(2.8, 3.2)), + Row.of(4, Vectors.dense(300., 3.2)), + Row.of(1, Vectors.dense(2.2, 3.2)), + Row.of(5, Vectors.dense(2.4, 3.2)), + Row.of(5, Vectors.dense(2.5, 3.2)), + Row.of(5, Vectors.dense(2.5, 3.2)), + Row.of(1, Vectors.dense(2.1, 3.1)))); + + private static final List<Row> testArray = + new ArrayList<>( + Arrays.asList( + Row.of(5, Vectors.dense(4.0, 4.1)), Row.of(2, Vectors.dense(300, 42)))); + private Table testData; + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(4); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + + Schema schema = + Schema.newBuilder() + .column("f0", DataTypes.INT()) + .column("f1", DataTypes.of(DenseVector.class)) + .build(); + + DataStream<Row> dataStream = env.fromCollection(trainArray); + trainData = tEnv.fromDataStream(dataStream, schema).as(LABEL_COL + "," + VEC_COL); + + 
DataStream<Row> predDataStream = env.fromCollection(testArray); + testData = tEnv.fromDataStream(predDataStream, schema).as(LABEL_COL + "," + VEC_COL); + } + + // Executes the graph and returns a list which has true label and predict label. + private static List<Tuple2<String, String>> executeAndCollect(Table output) throws Exception { Review comment: Should the output type be `List<Tuple2<Integer, Integer>>`? Could you help check that other usages of `String` have been corrected in `KnnTest`?
*/ +public class KnnTest { + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + private Table trainData; + private static final String LABEL_COL = "test_label"; + private static final String PRED_COL = "test_prediction"; + private static final String VEC_COL = "test_features"; + private static final List<Row> trainArray = + new ArrayList<>( + Arrays.asList( + Row.of(1, Vectors.dense(2.0, 3.0)), + Row.of(1, Vectors.dense(2.1, 3.1)), + Row.of(2, Vectors.dense(200.1, 300.1)), + Row.of(2, Vectors.dense(200.2, 300.2)), + Row.of(2, Vectors.dense(200.3, 300.3)), + Row.of(2, Vectors.dense(200.4, 300.4)), + Row.of(2, Vectors.dense(200.4, 300.4)), + Row.of(2, Vectors.dense(200.6, 300.6)), + Row.of(1, Vectors.dense(2.1, 3.1)), + Row.of(1, Vectors.dense(2.1, 3.1)), + Row.of(1, Vectors.dense(2.1, 3.1)), + Row.of(1, Vectors.dense(2.1, 3.1)), + Row.of(1, Vectors.dense(2.3, 3.2)), + Row.of(1, Vectors.dense(2.3, 3.2)), + Row.of(3, Vectors.dense(2.8, 3.2)), + Row.of(4, Vectors.dense(300., 3.2)), + Row.of(1, Vectors.dense(2.2, 3.2)), + Row.of(5, Vectors.dense(2.4, 3.2)), + Row.of(5, Vectors.dense(2.5, 3.2)), + Row.of(5, Vectors.dense(2.5, 3.2)), + Row.of(1, Vectors.dense(2.1, 3.1)))); + + private static final List<Row> testArray = + new ArrayList<>( + Arrays.asList( + Row.of(5, Vectors.dense(4.0, 4.1)), Row.of(2, Vectors.dense(300, 42)))); + private Table testData; + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(4); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + + Schema schema = + Schema.newBuilder() + .column("f0", DataTypes.INT()) Review comment: Previously I also thought it was cleaner to use `int` to represent a categorical field. And `KMeans` currently outputs the prediction column as `int`. However, @zhipeng93 said most algorithm users prefer the output/input type to be consistently `double`. And Spark also uses `double` to represent categorical fields. While I don't have a strong opinion on this issue, IMO it is probably useful to follow the same practice in Flink ML. Could you check with @zhipeng93 and decide whether we should use int or double here?
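For comparison, the `double` convention would change the test fixtures roughly as follows (a sketch; only the label representation changes):

```java
Schema schema =
        Schema.newBuilder()
                .column("f0", DataTypes.DOUBLE())
                .column("f1", DataTypes.of(DenseVector.class))
                .build();

List<Row> trainArray =
        Arrays.asList(
                Row.of(1.0, Vectors.dense(2.0, 3.0)),
                Row.of(2.0, Vectors.dense(200.1, 300.1)));
```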
########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java ########## @@ -0,0 +1,273 @@ +package org.apache.flink.ml.classification.knn; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.linalg.DenseMatrix; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.nio.file.Files; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; + +import static org.junit.Assert.assertEquals; + +/** Tests Knn and KnnModel. */ +public class KnnTest { + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + private Table trainData; + private static final String LABEL_COL = "test_label"; Review comment: It is not clear why we need to explicitly set column names for every test. Any chance that we could simplify the test code, and also make our test code consistent with each other, by not using `LABEL_COL/VEC_COL/PRED_COL` here? Maybe follow `LogisticRegressionTest` for example? ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java ########## @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.classification.knn; + +import org.apache.flink.api.common.serialization.Encoder; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.connector.file.src.reader.SimpleStreamFormat; +import org.apache.flink.core.fs.FSDataInputStream; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.ml.linalg.DenseMatrix; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.typeinfo.DenseMatrixSerializer; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.types.Row; + +import java.io.EOFException; +import java.io.IOException; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Objects; + +/** Knn model data, which stores the data used to calculate the distances between nodes. */ +public class KnnModelData { + private final List<Tuple3<DenseMatrix, DenseVector, int[]>> dictData; Review comment: Could we name this field as `modelData` for consistency with other algorithms? Could we add Java doc explaining the meaning of these three fields? ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java ########## @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.classification.knn; + +import org.apache.flink.api.common.serialization.Encoder; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.connector.file.src.reader.SimpleStreamFormat; +import org.apache.flink.core.fs.FSDataInputStream; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.ml.linalg.DenseMatrix; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.typeinfo.DenseMatrixSerializer; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.types.Row; + +import java.io.EOFException; +import java.io.IOException; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Objects; + +/** Knn model data, which stores the data used to calculate the distances between nodes. */ +public class KnnModelData { + private final List<Tuple3<DenseMatrix, DenseVector, int[]>> dictData; + private final Comparator<? super Tuple2<Double, Integer>> comparator; + + /** + * Constructor. + * + * @param list Row list. + */ + public KnnModelData(List<Row> list) { + this.dictData = new ArrayList<>(list.size()); + for (Row row : list) { + this.dictData.add( + Tuple3.of( + (DenseMatrix) row.getField(0), + (DenseVector) row.getField(1), + (int[]) row.getField(2))); + } + comparator = Comparator.comparingDouble(o -> -o.f0); + } + + /** + * Gets comparator. + * + * @return Comparator. + */ + public Comparator<? super Tuple2<Double, Integer>> getQueueComparator() { + return comparator; + } + + /** + * Gets dictionary data size. + * + * @return Dictionary data size. + */ + public Integer getLength() { + return dictData.size(); + } + + public List<Tuple3<DenseMatrix, DenseVector, int[]>> getDictData() { + return dictData; + } + + /** Encoder for the Knn model data. */ + public static class ModelDataEncoder implements Encoder<Row> { + @Override + public void encode(Row modelData, OutputStream outputStream) throws IOException { + DataOutputView dataOutputView = new DataOutputViewStreamWrapper(outputStream); + + DenseMatrixSerializer matrixSerializer = new DenseMatrixSerializer(); + matrixSerializer.serialize((DenseMatrix) modelData.getField(0), dataOutputView); + + DenseVectorSerializer vectorSerializer = new DenseVectorSerializer(); + vectorSerializer.serialize((DenseVector) modelData.getField(1), dataOutputView); + + int[] label = (int[]) Objects.requireNonNull(modelData.getField(2)); + for (Integer integer : label) { + dataOutputView.writeInt(integer); + } + } + } + + /** Decoder for the Knn model data. 
*/ +public static class ModelDataStreamFormat extends SimpleStreamFormat<Row> { + + @Override + public Reader<Row> createReader(Configuration config, FSDataInputStream stream) { + return new Reader<Row>() { + + @Override + public Row read() throws IOException { + try { + DataInputView source = new DataInputViewStreamWrapper(stream); Review comment: Could we make `source`, `DenseMatrixSerializer()` and `DenseVectorSerializer()` private final fields of the anonymous `Reader` class, to avoid creating these instances repeatedly for every invocation of read()? ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java ########## @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.knn; + +import org.apache.flink.api.common.serialization.Encoder; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.connector.file.src.reader.SimpleStreamFormat; +import org.apache.flink.core.fs.FSDataInputStream; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.ml.linalg.DenseMatrix; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.typeinfo.DenseMatrixSerializer; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.types.Row; + +import java.io.EOFException; +import java.io.IOException; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Objects; + +/** Knn model data, which stores the data used to calculate the distances between nodes. */ +public class KnnModelData { + private final List<Tuple3<DenseMatrix, DenseVector, int[]>> dictData; + private final Comparator<? super Tuple2<Double, Integer>> comparator; + + /** + * Constructor. + * + * @param list Row list. + */ + public KnnModelData(List<Row> list) { + this.dictData = new ArrayList<>(list.size()); + for (Row row : list) { + this.dictData.add( + Tuple3.of( + (DenseMatrix) row.getField(0), + (DenseVector) row.getField(1), + (int[]) row.getField(2))); + } + comparator = Comparator.comparingDouble(o -> -o.f0); + } + + /** + * Gets comparator. + * + * @return Comparator. + */ + public Comparator<?
super Tuple2<Double, Integer>> getQueueComparator() { + return comparator; + } + + /** + * Gets dictionary data size. + * + * @return Dictionary data size. + */ + public Integer getLength() { + return dictData.size(); + } + + public List<Tuple3<DenseMatrix, DenseVector, int[]>> getDictData() { Review comment: Would it be simpler (and also more consistent with other algorithms) to make `dictData` a `public final` field and let caller access this field directly? ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java ########## @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.knn; + +import org.apache.flink.api.common.serialization.Encoder; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.connector.file.src.reader.SimpleStreamFormat; +import org.apache.flink.core.fs.FSDataInputStream; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.ml.linalg.DenseMatrix; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.typeinfo.DenseMatrixSerializer; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.types.Row; + +import java.io.EOFException; +import java.io.IOException; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Objects; + +/** Knn model data, which stores the data used to calculate the distances between nodes. */ +public class KnnModelData { + private final List<Tuple3<DenseMatrix, DenseVector, int[]>> dictData; + private final Comparator<? super Tuple2<Double, Integer>> comparator; + + /** + * Constructor. + * + * @param list Row list. + */ + public KnnModelData(List<Row> list) { + this.dictData = new ArrayList<>(list.size()); + for (Row row : list) { + this.dictData.add( + Tuple3.of( + (DenseMatrix) row.getField(0), + (DenseVector) row.getField(1), + (int[]) row.getField(2))); + } + comparator = Comparator.comparingDouble(o -> -o.f0); + } + + /** + * Gets comparator. + * + * @return Comparator. + */ + public Comparator<? super Tuple2<Double, Integer>> getQueueComparator() { + return comparator; + } + + /** + * Gets dictionary data size. 
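
For illustration, the refactoring suggested above for `ModelDataStreamFormat` could look roughly like this (a sketch, assuming the enclosing `createReader(Configuration, FSDataInputStream)` stays as in the PR):

```java
return new Reader<Row>() {
    // Created once per reader, instead of once per read() invocation.
    private final DataInputView source = new DataInputViewStreamWrapper(stream);
    private final DenseMatrixSerializer matrixSerializer = new DenseMatrixSerializer();
    private final DenseVectorSerializer vectorSerializer = new DenseVectorSerializer();

    @Override
    public Row read() throws IOException {
        try {
            DenseMatrix matrix = matrixSerializer.deserialize(source);
            DenseVector vector = vectorSerializer.deserialize(source);
            int[] label = new int[vector.size()];
            for (int i = 0; i < label.length; ++i) {
                label[i] = source.readInt();
            }
            return Row.of(matrix, vector, label);
        } catch (EOFException e) {
            return null;
        }
    }

    @Override
    public void close() throws IOException {
        stream.close();
    }
};
```
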
+ * + * @return Dictionary data size. + */ + public Integer getLength() { + return dictData.size(); + } + + public List<Tuple3<DenseMatrix, DenseVector, int[]>> getDictData() { + return dictData; + } + + /** Encoder for the Knn model data. */ + public static class ModelDataEncoder implements Encoder<Row> { + @Override + public void encode(Row modelData, OutputStream outputStream) throws IOException { + DataOutputView dataOutputView = new DataOutputViewStreamWrapper(outputStream); + + DenseMatrixSerializer matrixSerializer = new DenseMatrixSerializer(); + matrixSerializer.serialize((DenseMatrix) modelData.getField(0), dataOutputView); + + DenseVectorSerializer vectorSerializer = new DenseVectorSerializer(); + vectorSerializer.serialize((DenseVector) modelData.getField(1), dataOutputView); + + int[] label = (int[]) Objects.requireNonNull(modelData.getField(2)); + for (Integer integer : label) { + dataOutputView.writeInt(integer); + } + } + } + + /** Decoder for the Knn model data. */ + public static class ModelDataStreamFormat extends SimpleStreamFormat<Row> { + + @Override + public Reader<Row> createReader(Configuration config, FSDataInputStream stream) { + return new Reader<Row>() { + + @Override + public Row read() throws IOException { + try { + DataInputView source = new DataInputViewStreamWrapper(stream); + DenseMatrix matrix = new DenseMatrixSerializer().deserialize(source); + DenseVector vector = new DenseVectorSerializer().deserialize(source); + int[] label = new int[vector.size()]; + for (int i = 0; i < label.length; ++i) { + label[i] = source.readInt(); + } + return Row.of(matrix, vector, label); + } catch (EOFException e) { + return null; + } + } + + @Override + public void close() throws IOException { + stream.close(); + } + }; + } + + @Override + public TypeInformation<Row> getProducedType() { + return getRowTypeInfo(); + } + } + + protected static Schema getModelSchema() { + return Schema.newBuilder() + .column("VECTORS", DataTypes.of(DenseMatrix.class)) 
Review comment: For simplicity and consistency, could we use the same approach as `KMeansModelData` and `LogisticRegressionModelData` by using `f0, f1, f2` as the field names? If we decide to use human readable names here, it would be nice to enforce this practice consistently across algorithms. And we probably want to use lowercase letters here so that the names of these fields are consistent with the default values of e.g. `HasFeaturesCol`. 
########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java ########## @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
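
With the `f0, f1, f2` naming suggested above, `getModelSchema()` could look roughly like this (a sketch; the types for `f1` and `f2` are inferred from `ModelDataEncoder`, which writes a `DenseVector` and an `int[]`, and the `ARRAY(INT())` choice for `f2` is an assumption, not code from the PR):

```java
protected static Schema getModelSchema() {
    return Schema.newBuilder()
            .column("f0", DataTypes.of(DenseMatrix.class))
            .column("f1", DataTypes.of(DenseVector.class))
            .column("f2", DataTypes.ARRAY(DataTypes.INT()))
            .build();
}
```
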
+ */ + +package org.apache.flink.ml.classification.knn; + +import org.apache.flink.api.common.serialization.Encoder; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.connector.file.src.reader.SimpleStreamFormat; +import org.apache.flink.core.fs.FSDataInputStream; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.ml.linalg.DenseMatrix; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.typeinfo.DenseMatrixSerializer; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.types.Row; + +import java.io.EOFException; +import java.io.IOException; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Objects; + +/** Knn model data, which stores the data used to calculate the distances between nodes. */ +public class KnnModelData { + private final List<Tuple3<DenseMatrix, DenseVector, int[]>> dictData; + private final Comparator<? super Tuple2<Double, Integer>> comparator; + + /** + * Constructor. + * + * @param list Row list. + */ + public KnnModelData(List<Row> list) { + this.dictData = new ArrayList<>(list.size()); + for (Row row : list) { + this.dictData.add( + Tuple3.of( + (DenseMatrix) row.getField(0), + (DenseVector) row.getField(1), + (int[]) row.getField(2))); + } + comparator = Comparator.comparingDouble(o -> -o.f0); + } + + /** + * Gets comparator. + * + * @return Comparator. + */ + public Comparator<? super Tuple2<Double, Integer>> getQueueComparator() { + return comparator; + } + + /** + * Gets dictionary data size. + * + * @return Dictionary data size. + */ + public Integer getLength() { + return dictData.size(); + } + + public List<Tuple3<DenseMatrix, DenseVector, int[]>> getDictData() { + return dictData; + } + + /** Encoder for the Knn model data. */ + public static class ModelDataEncoder implements Encoder<Row> { + @Override + public void encode(Row modelData, OutputStream outputStream) throws IOException { + DataOutputView dataOutputView = new DataOutputViewStreamWrapper(outputStream); + + DenseMatrixSerializer matrixSerializer = new DenseMatrixSerializer(); + matrixSerializer.serialize((DenseMatrix) modelData.getField(0), dataOutputView); + + DenseVectorSerializer vectorSerializer = new DenseVectorSerializer(); + vectorSerializer.serialize((DenseVector) modelData.getField(1), dataOutputView); + + int[] label = (int[]) Objects.requireNonNull(modelData.getField(2)); + for (Integer integer : label) { + dataOutputView.writeInt(integer); + } + } + } + + /** Decoder for the Knn model data. 
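
The `public final` field suggested earlier for `dictData` would reduce `KnnModelData` to roughly the following sketch, with `getDictData()` and `getLength()` removed and callers reading the field directly:

```java
public class KnnModelData {
    public final List<Tuple3<DenseMatrix, DenseVector, int[]>> dictData;

    public KnnModelData(List<Row> list) {
        dictData = new ArrayList<>(list.size());
        for (Row row : list) {
            dictData.add(
                    Tuple3.of(
                            (DenseMatrix) row.getField(0),
                            (DenseVector) row.getField(1),
                            (int[]) row.getField(2)));
        }
    }
}
```

Callers would then use `modelData.dictData.size()` where `getLength()` is used today.
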
*/ + public static class ModelDataStreamFormat extends SimpleStreamFormat<Row> { + + @Override + public Reader<Row> createReader(Configuration config, FSDataInputStream stream) { + return new Reader<Row>() { + + @Override + public Row read() throws IOException { + try { + DataInputView source = new DataInputViewStreamWrapper(stream); + DenseMatrix matrix = new DenseMatrixSerializer().deserialize(source); + DenseVector vector = new DenseVectorSerializer().deserialize(source); + int[] label = new int[vector.size()]; + for (int i = 0; i < label.length; ++i) { + label[i] = source.readInt(); + } + return Row.of(matrix, vector, label); + } catch (EOFException e) { + return null; + } + } + + @Override + public void close() throws IOException { + stream.close(); + } + }; + } + + @Override + public TypeInformation<Row> getProducedType() { Review comment: `KMeansModelData` and `LogisticRegressionModelData` do not override `getProducedType`. Could we do the same here for simplicity? ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java ########## @@ -0,0 +1,378 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.classification.knn; + +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.api.common.functions.AbstractRichFunction; +import org.apache.flink.api.connector.source.Source; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.connector.file.sink.FileSink; +import org.apache.flink.connector.file.src.FileSource; +import org.apache.flink.core.fs.Path; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseMatrix; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner; +import org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy; +import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.table.catalog.ResolvedSchema; +import org.apache.flink.table.types.DataType; +import org.apache.flink.types.Row; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.PriorityQueue; + +/** Knn model fitted by estimator. */ +public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> { + protected Map<Param<?>, Object> params = new HashMap<>(); + private Table[] modelData; + + /** Constructor. */ + public KnnModel() { + ParamUtils.initializeMapWithDefaultValues(params, this); + } + + /** + * Sets model data for knn prediction. + * + * @param modelData Knn model data. + * @return Knn model. + */ + @Override + public KnnModel setModelData(Table... modelData) { + this.modelData = modelData; + return this; + } + + /** + * Gets model data. + * + * @return Table array including model data tables. + */ + @Override + public Table[] getModelData() { + return modelData; + } + + /** + * Predicts label with knn model. + * + * @param inputs List of tables. + * @return Prediction result. + */ + @Override + @SuppressWarnings("unchecked") + public Table[] transform(Table... 
inputs) { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<Row> input = tEnv.toDataStream(inputs[0]); + DataStream<Row> model = tEnv.toDataStream(modelData[0]); + final String broadcastKey = "broadcastModelKey"; + String resultCols = getPredictionCol(); + DataType resultTypes = DataTypes.INT(); + ResolvedSchema outputSchema = + TableUtils.getOutputSchema(inputs[0].getResolvedSchema(), resultCols, resultTypes); + + DataStream<Row> output = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(input), + Collections.singletonMap(broadcastKey, model), + inputList -> { + DataStream inoutData = inputList.get(0); + return inoutData.transform( + "mapFunc", + TableUtils.getRowTypeInfo(outputSchema), + new PredictOperator(broadcastKey, getK(), getFeaturesCol())); + }); + + return new Table[] {tEnv.fromDataStream(output)}; + } + + /** @return Parameters for algorithm. */ + @Override + public Map<Param<?>, Object> getParamMap() { + return this.params; + } + + /** + * Saves model data. + * + * @param path Path to save. + */ + @Override + public void save(String path) throws IOException { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) modelData[0]).getTableEnvironment(); + + String dataPath = ReadWriteUtils.getDataPath(path); + FileSink<Row> sink = + FileSink.forRowFormat(new Path(dataPath), new KnnModelData.ModelDataEncoder()) + .withRollingPolicy(OnCheckpointRollingPolicy.build()) + .withBucketAssigner(new BasePathBucketAssigner<>()) + .build(); + tEnv.toDataStream(modelData[0]).sinkTo(sink); + ReadWriteUtils.saveMetadata(this, path); + } + + /** + * Loads model data from path. + * + * @param env Stream execution environment. + * @param path Model path. + * @return Knn model. + */ + public static KnnModel load(StreamExecutionEnvironment env, String path) throws IOException { + StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); + KnnModel retModel = ReadWriteUtils.loadStageParam(path); + + Source<Row, ?, ?> source = + FileSource.forRecordStreamFormat( + new KnnModelData.ModelDataStreamFormat(), + ReadWriteUtils.getDataPaths(path)) + .build(); + DataStream<Row> modelDataStream = + env.fromSource(source, WatermarkStrategy.noWatermarks(), "data"); + retModel.modelData = + new Table[] {tEnv.fromDataStream(modelDataStream, KnnModelData.getModelSchema())}; + return retModel; + } + + /** This operator loads model data and predicts result. 
*/ + private static class PredictOperator + extends AbstractUdfStreamOperator<Row, AbstractRichFunction> + implements OneInputStreamOperator<Row, Row> { + + private boolean firstEle = true; + private final String featureCol; + private transient KnnModelData modelData; + private final Integer topN; + private final String broadcastKey; + + public PredictOperator(String broadcastKey, int k, String featureCol) { + super(new AbstractRichFunction() {}); + this.topN = k; + this.broadcastKey = broadcastKey; + this.featureCol = featureCol; + } + + @Override + public void processElement(StreamRecord<Row> streamRecord) { + Row value = streamRecord.getValue(); + output.collect(new StreamRecord<>(map(value))); + } + + private Row map(Row row) { + if (firstEle) { + loadModel(userFunction.getRuntimeContext().getBroadcastVariable(broadcastKey)); + firstEle = false; + } + DenseVector vector = (DenseVector) row.getField(featureCol); + Tuple2<List<Integer>, List<Double>> t2 = findNeighbor(vector, topN, modelData); + Row ret = new Row(row.getArity() + 1); + for (int i = 0; i < row.getArity(); ++i) { + ret.setField(i, row.getField(i)); + } + + ret.setField(row.getArity(), getResult(t2)); + return ret; + } + + /** + * Finds the nearest topN neighbors from whole nodes. + * + * @param input Input vector. + * @param topN Top N. + * @return Neighbors. + */ + private Tuple2<List<Integer>, List<Double>> findNeighbor( + DenseVector input, Integer topN, KnnModelData modelData) { + PriorityQueue<Tuple2<Double, Integer>> priorityQueue = + new PriorityQueue<>(modelData.getQueueComparator()); Review comment: Would it be simpler to instantiate `Comparator.comparingDouble(o -> -o.f0)` as a member field of `PredictOperator` directly instead of calling `getQueueComparator(...)`? ########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java ########## @@ -0,0 +1,273 @@ +package org.apache.flink.ml.classification.knn; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.linalg.DenseMatrix; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.nio.file.Files; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; + +import static org.junit.Assert.assertEquals; + +/** Tests Knn and KnnModel. 
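
The comparator suggestion above could be implemented along these lines inside `PredictOperator` (a sketch; making the field `static final` is an extra assumption so that Java serialization of the operator does not have to serialize the lambda, and `java.util.Comparator` would need to be imported in `KnnModel.java`):

```java
private static class PredictOperator
        extends AbstractUdfStreamOperator<Row, AbstractRichFunction>
        implements OneInputStreamOperator<Row, Row> {

    // Orders candidates by descending distance, so the head of the queue is the worst match.
    private static final Comparator<Tuple2<Double, Integer>> QUEUE_COMPARATOR =
            Comparator.comparingDouble(o -> -o.f0);

    private Tuple2<List<Integer>, List<Double>> findNeighbor(
            DenseVector input, Integer topN, KnnModelData modelData) {
        PriorityQueue<Tuple2<Double, Integer>> priorityQueue =
                new PriorityQueue<>(QUEUE_COMPARATOR);
        // ... rest of findNeighbor unchanged ...
    }
}
```
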
*/ +public class KnnTest { + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + private Table trainData; + private static final String LABEL_COL = "test_label"; + private static final String PRED_COL = "test_prediction"; + private static final String VEC_COL = "test_features"; + private static final List<Row> trainArray = + new ArrayList<>( + Arrays.asList( + Row.of(1, Vectors.dense(2.0, 3.0)), + Row.of(1, Vectors.dense(2.1, 3.1)), + Row.of(2, Vectors.dense(200.1, 300.1)), + Row.of(2, Vectors.dense(200.2, 300.2)), + Row.of(2, Vectors.dense(200.3, 300.3)), + Row.of(2, Vectors.dense(200.4, 300.4)), + Row.of(2, Vectors.dense(200.4, 300.4)), + Row.of(2, Vectors.dense(200.6, 300.6)), + Row.of(1, Vectors.dense(2.1, 3.1)), + Row.of(1, Vectors.dense(2.1, 3.1)), + Row.of(1, Vectors.dense(2.1, 3.1)), + Row.of(1, Vectors.dense(2.1, 3.1)), + Row.of(1, Vectors.dense(2.3, 3.2)), + Row.of(1, Vectors.dense(2.3, 3.2)), + Row.of(3, Vectors.dense(2.8, 3.2)), + Row.of(4, Vectors.dense(300., 3.2)), + Row.of(1, Vectors.dense(2.2, 3.2)), + Row.of(5, Vectors.dense(2.4, 3.2)), + Row.of(5, Vectors.dense(2.5, 3.2)), + Row.of(5, Vectors.dense(2.5, 3.2)), + Row.of(1, Vectors.dense(2.1, 3.1)))); + + private static final List<Row> testArray = + new ArrayList<>( + Arrays.asList( + Row.of(5, Vectors.dense(4.0, 4.1)), Row.of(2, Vectors.dense(300, 42)))); + private Table testData; + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(4); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + + Schema schema = + Schema.newBuilder() + .column("f0", DataTypes.INT()) + .column("f1", DataTypes.of(DenseVector.class)) + .build(); + + DataStream<Row> dataStream = env.fromCollection(trainArray); + trainData = tEnv.fromDataStream(dataStream, schema).as(LABEL_COL + "," + VEC_COL); + + DataStream<Row> predDataStream = env.fromCollection(testArray); + testData = tEnv.fromDataStream(predDataStream, schema).as(LABEL_COL + "," + VEC_COL); + } + + // Executes the graph and returns a list which has true label and predict label. + private static List<Tuple2<String, String>> executeAndCollect(Table output) throws Exception { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment(); + + DataStream<Tuple2<Integer, Integer>> stream = + tEnv.toDataStream(output) + .map( + new MapFunction<Row, Tuple2<Integer, Integer>>() { + @Override + public Tuple2<Integer, Integer> map(Row row) { + return Tuple2.of( + (Integer) row.getField(LABEL_COL), + (Integer) row.getField(PRED_COL)); + } + }); + return IteratorUtils.toList(stream.executeAndCollect()); + } + + private static void verifyClusteringResult(List<Tuple2<String, String>> result) { + for (Tuple2<String, String> t2 : result) { + Assert.assertEquals(t2.f0, t2.f1); + } + } + + /** Tests Param. 
*/ + @Test + public void testParam() { + Knn knn = new Knn(); + assertEquals("features", knn.getFeaturesCol()); + assertEquals("label", knn.getLabelCol()); + assertEquals(10L, knn.getK().longValue()); + assertEquals("prediction", knn.getPredictionCol()); + + knn.setLabelCol(LABEL_COL).setFeaturesCol(VEC_COL).setK(4).setPredictionCol(PRED_COL); + + assertEquals(VEC_COL, knn.getFeaturesCol()); + assertEquals(LABEL_COL, knn.getLabelCol()); + assertEquals(4L, knn.getK().longValue()); + assertEquals(PRED_COL, knn.getPredictionCol()); + } + + @Test + public void testFeaturePredictionParam() throws Exception { + Knn knn = + new Knn() + .setLabelCol(LABEL_COL) + .setFeaturesCol(VEC_COL) + .setK(4) + .setPredictionCol(PRED_COL); + KnnModel model = knn.fit(trainData); + Table output = model.transform(testData)[0]; + + assertEquals( + Arrays.asList(LABEL_COL, VEC_COL, PRED_COL), + output.getResolvedSchema().getColumnNames()); + + List<Tuple2<String, String>> result = executeAndCollect(output); + verifyClusteringResult(result); + } + + @Test + public void testFewerDistinctPointsThanCluster() throws Exception { + Knn knn = + new Knn() + .setLabelCol(LABEL_COL) + .setFeaturesCol(VEC_COL) + .setK(4) + .setPredictionCol(PRED_COL); + KnnModel model = knn.fit(testData); + Table output = model.transform(testData)[0]; + + assertEquals( + Arrays.asList(LABEL_COL, VEC_COL, PRED_COL), + output.getResolvedSchema().getColumnNames()); + executeAndCollect(output); + } + + @Test + public void testFitAndPredict() throws Exception { + Knn knn = + new Knn() + .setLabelCol(LABEL_COL) + .setFeaturesCol(VEC_COL) + .setK(4) + .setPredictionCol(PRED_COL); + KnnModel knnModel = knn.fit(trainData); + Table output = knnModel.transform(testData)[0]; + List<Tuple2<String, String>> result = executeAndCollect(output); + verifyClusteringResult(result); + } + + @Test + public void testSaveLoadAndPredict() throws Exception { + String path = Files.createTempDirectory("").toString(); + Knn knn = + new Knn() + .setLabelCol(LABEL_COL) + .setFeaturesCol(VEC_COL) + .setK(4) + .setPredictionCol(PRED_COL); + knn.save(path); + + Knn loadKnn = Knn.load(path); + KnnModel knnModel = loadKnn.fit(trainData); + Table output = knnModel.transform(testData)[0]; + + List<Tuple2<String, String>> result = executeAndCollect(output); + verifyClusteringResult(result); + } + + @Test + public void testModelSaveLoadAndPredict() throws Exception { + String path = Files.createTempDirectory("").toString(); + Knn knn = + new Knn() + .setLabelCol(LABEL_COL) + .setFeaturesCol(VEC_COL) + .setK(4) + .setPredictionCol(PRED_COL); + KnnModel knnModel = knn.fit(trainData); + knnModel.save(path); + env.execute(); + + KnnModel newModel = KnnModel.load(env, path); + Table output = newModel.transform(testData)[0]; + List<Tuple2<String, String>> result = executeAndCollect(output); + verifyClusteringResult(result); + } + + @Test + public void testGetModelData() throws Exception { + Knn knn = + new Knn() + .setLabelCol(LABEL_COL) + .setFeaturesCol(VEC_COL) + .setK(4) + .setPredictionCol(PRED_COL); + + KnnModel knnModel = knn.fit(trainData); + Table modelData = knnModel.getModelData()[0]; + + DataStream<Row> output = tEnv.toDataStream(modelData); + + assertEquals( + Arrays.asList("VECTORS", "NORM", "LABEL"), + modelData.getResolvedSchema().getColumnNames()); + + List<Row> modelRows = IteratorUtils.toList(output.executeAndCollect()); + for (Row modelRow : modelRows) { + DenseMatrix vectors = (DenseMatrix) modelRow.getField(0); + DenseVector label = (DenseVector) 
Objects.requireNonNull(modelRow.getField(1)); + + assertEquals(2, Objects.requireNonNull(vectors).numRows); 
Review comment: Could we simplify the code here by removing `Objects.requireNonNull`? It looks like the behavior of this line when `vectors` is null would be the same after we remove it. In general, in the Flink codebase we don't explicitly do `requireNonNull(...)` and it is OK to just throw NPE. Could you check other usages of `requireNonNull(...)` and follow the existing practice in Flink? 
########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java ########## @@ -0,0 +1,378 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.knn; + +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.api.common.functions.AbstractRichFunction; +import org.apache.flink.api.connector.source.Source; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.connector.file.sink.FileSink; +import org.apache.flink.connector.file.src.FileSource; +import org.apache.flink.core.fs.Path; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseMatrix; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner; +import org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy; +import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.table.catalog.ResolvedSchema; +import org.apache.flink.table.types.DataType; +import org.apache.flink.types.Row; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.PriorityQueue; + +/** Knn model fitted by estimator.
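
Dropping `Objects.requireNonNull` as suggested above leaves plain casts in `testGetModelData`; a null field then fails the test with an NPE on the same line, so the test's behavior is effectively unchanged. A sketch:

```java
DenseMatrix vectors = (DenseMatrix) modelRow.getField(0);
DenseVector label = (DenseVector) modelRow.getField(1);

// Throws NullPointerException (and thus fails the test) if vectors is null.
assertEquals(2, vectors.numRows);
```
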
*/ +public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> { + protected Map<Param<?>, Object> params = new HashMap<>(); + private Table[] modelData; + + /** Constructor. */ + public KnnModel() { + ParamUtils.initializeMapWithDefaultValues(params, this); + } + + /** + * Sets model data for knn prediction. + * + * @param modelData Knn model data. + * @return Knn model. + */ + @Override + public KnnModel setModelData(Table... modelData) { + this.modelData = modelData; + return this; + } + + /** + * Gets model data. + * + * @return Table array including model data tables. + */ + @Override + public Table[] getModelData() { + return modelData; + } + + /** + * Predicts label with knn model. + * + * @param inputs List of tables. + * @return Prediction result. + */ + @Override + @SuppressWarnings("unchecked") + public Table[] transform(Table... inputs) { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<Row> input = tEnv.toDataStream(inputs[0]); + DataStream<Row> model = tEnv.toDataStream(modelData[0]); + final String broadcastKey = "broadcastModelKey"; + String resultCols = getPredictionCol(); + DataType resultTypes = DataTypes.INT(); + ResolvedSchema outputSchema = + TableUtils.getOutputSchema(inputs[0].getResolvedSchema(), resultCols, resultTypes); + + DataStream<Row> output = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(input), + Collections.singletonMap(broadcastKey, model), + inputList -> { + DataStream inoutData = inputList.get(0); + return inoutData.transform( + "mapFunc", + TableUtils.getRowTypeInfo(outputSchema), + new PredictOperator(broadcastKey, getK(), getFeaturesCol())); + }); + + return new Table[] {tEnv.fromDataStream(output)}; + } + + /** @return Parameters for algorithm. */ + @Override + public Map<Param<?>, Object> getParamMap() { + return this.params; + } + + /** + * Saves model data. + * + * @param path Path to save. + */ + @Override + public void save(String path) throws IOException { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) modelData[0]).getTableEnvironment(); + + String dataPath = ReadWriteUtils.getDataPath(path); + FileSink<Row> sink = + FileSink.forRowFormat(new Path(dataPath), new KnnModelData.ModelDataEncoder()) + .withRollingPolicy(OnCheckpointRollingPolicy.build()) + .withBucketAssigner(new BasePathBucketAssigner<>()) + .build(); + tEnv.toDataStream(modelData[0]).sinkTo(sink); + ReadWriteUtils.saveMetadata(this, path); + } + + /** + * Loads model data from path. + * + * @param env Stream execution environment. + * @param path Model path. + * @return Knn model. + */ + public static KnnModel load(StreamExecutionEnvironment env, String path) throws IOException { + StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); + KnnModel retModel = ReadWriteUtils.loadStageParam(path); + + Source<Row, ?, ?> source = + FileSource.forRecordStreamFormat( + new KnnModelData.ModelDataStreamFormat(), + ReadWriteUtils.getDataPaths(path)) + .build(); + DataStream<Row> modelDataStream = + env.fromSource(source, WatermarkStrategy.noWatermarks(), "data"); + retModel.modelData = + new Table[] {tEnv.fromDataStream(modelDataStream, KnnModelData.getModelSchema())}; + return retModel; + } + + /** This operator loads model data and predicts result. 
*/ + private static class PredictOperator + extends AbstractUdfStreamOperator<Row, AbstractRichFunction> + implements OneInputStreamOperator<Row, Row> { + + private boolean firstEle = true; + private final String featureCol; + private transient KnnModelData modelData; + private final Integer topN; + private final String broadcastKey; + + public PredictOperator(String broadcastKey, int k, String featureCol) { + super(new AbstractRichFunction() {}); + this.topN = k; + this.broadcastKey = broadcastKey; + this.featureCol = featureCol; + } + + @Override + public void processElement(StreamRecord<Row> streamRecord) { + Row value = streamRecord.getValue(); + output.collect(new StreamRecord<>(map(value))); + } + + private Row map(Row row) { + if (firstEle) { + loadModel(userFunction.getRuntimeContext().getBroadcastVariable(broadcastKey)); + firstEle = false; + } + DenseVector vector = (DenseVector) row.getField(featureCol); + Tuple2<List<Integer>, List<Double>> t2 = findNeighbor(vector, topN, modelData); + Row ret = new Row(row.getArity() + 1); + for (int i = 0; i < row.getArity(); ++i) { + ret.setField(i, row.getField(i)); + } + + ret.setField(row.getArity(), getResult(t2)); + return ret; + } + + /** + * Finds the nearest topN neighbors from whole nodes. + * + * @param input Input vector. + * @param topN Top N. + * @return Neighbors. + */ + private Tuple2<List<Integer>, List<Double>> findNeighbor( + DenseVector input, Integer topN, KnnModelData modelData) { + PriorityQueue<Tuple2<Double, Integer>> priorityQueue = + new PriorityQueue<>(modelData.getQueueComparator()); + search(input, topN, priorityQueue, modelData); + List<Integer> items = new ArrayList<>(); + List<Double> metrics = new ArrayList<>(); + while (!priorityQueue.isEmpty()) { + Tuple2<Double, Integer> result = priorityQueue.poll(); + items.add(result.f1); + metrics.add(result.f0); + } + Collections.reverse(items); + Collections.reverse(metrics); + priorityQueue.clear(); + return Tuple2.of(items, metrics); + } + + /** + * @param input Input vector. + * @param topN Top N. + * @param priorityQueue Priority queue. + */ + private void search( + DenseVector input, + Integer topN, + PriorityQueue<Tuple2<Double, Integer>> priorityQueue, + KnnModelData modelData) { + Tuple2<DenseVector, Double> sample = computeNorm(input); + Tuple2<Double, Integer> head = null; + for (int i = 0; i < modelData.getLength(); i++) { + List<Tuple2<Double, Integer>> values = computeDistance(sample, i); + for (Tuple2<Double, Integer> currentValue : values) { + head = updateQueue(priorityQueue, topN, currentValue, head); + } + } + } + + /** + * Updates queue. + * + * @param pq Queue. + * @param topN Top N. + * @param newValue New value. + * @param head Head value. + * @return Head value. + */ + private <T> Tuple2<Double, T> updateQueue( + PriorityQueue<Tuple2<Double, T>> pq, + int topN, + Tuple2<Double, T> newValue, + Tuple2<Double, T> head) { + if (pq.size() < topN) { + pq.add(Tuple2.of(newValue.f0, newValue.f1)); + head = pq.peek(); + } else { + if (pq.comparator().compare(head, newValue) < 0) { + Tuple2<Double, T> peek = pq.poll(); + assert peek != null; + peek.f0 = newValue.f0; + peek.f1 = newValue.f1; + pq.add(peek); + head = pq.peek(); + } + } + return head; + } + + /** + * Computes distance between sample and dictionary vectors. + * + * @param input Samples with l2 norm. + * @param index Dictionary vectors index. + * @return Distances. 
+ */ + private List<Tuple2<Double, Integer>> computeDistance( + Tuple2<DenseVector, Double> input, Integer index) { + Tuple3<DenseMatrix, DenseVector, int[]> data = modelData.getDictData().get(index); + + DenseMatrix vectors = data.f0; + DenseMatrix distanceMatrix = new DenseMatrix(Objects.requireNonNull(vectors).numCols, 1); + + DenseVector norm = data.f1; + double[] normL2Square = Objects.requireNonNull(norm).values; + BLAS.gemv(-2.0, vectors, true, input.f0, 0.0, new DenseVector(distanceMatrix.values)); + for (int i = 0; i < distanceMatrix.values.length; i++) { + distanceMatrix.values[i] = Math.sqrt(Math.abs(distanceMatrix.values[i] + input.f1 + normL2Square[i])); + } + + List<Tuple2<Double, Integer>> list = new ArrayList<>(0); Review comment: nits: the initial capacity could be set to `curLabels.length` for better performance. ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java ########## @@ -0,0 +1,378 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
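
The capacity nit above is a one-line change (a sketch; `curLabels` refers to the label array that appears in the full `computeDistance` diff, not in the excerpt quoted here):

```java
List<Tuple2<Double, Integer>> list = new ArrayList<>(curLabels.length);
```
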
+ */ + +package org.apache.flink.ml.classification.knn; + +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.api.common.functions.AbstractRichFunction; +import org.apache.flink.api.connector.source.Source; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.connector.file.sink.FileSink; +import org.apache.flink.connector.file.src.FileSource; +import org.apache.flink.core.fs.Path; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseMatrix; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner; +import org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy; +import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.table.catalog.ResolvedSchema; +import org.apache.flink.table.types.DataType; +import org.apache.flink.types.Row; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.PriorityQueue; + +/** Knn model fitted by estimator. */ +public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> { + protected Map<Param<?>, Object> params = new HashMap<>(); + private Table[] modelData; + + /** Constructor. */ + public KnnModel() { + ParamUtils.initializeMapWithDefaultValues(params, this); + } + + /** + * Sets model data for knn prediction. + * + * @param modelData Knn model data. + * @return Knn model. + */ + @Override + public KnnModel setModelData(Table... modelData) { + this.modelData = modelData; + return this; + } + + /** + * Gets model data. + * + * @return Table array including model data tables. + */ + @Override + public Table[] getModelData() { + return modelData; + } + + /** + * Predicts label with knn model. + * + * @param inputs List of tables. + * @return Prediction result. + */ + @Override + @SuppressWarnings("unchecked") + public Table[] transform(Table... 
inputs) { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<Row> input = tEnv.toDataStream(inputs[0]); + DataStream<Row> model = tEnv.toDataStream(modelData[0]); + final String broadcastKey = "broadcastModelKey"; + String resultCols = getPredictionCol(); + DataType resultTypes = DataTypes.INT(); + ResolvedSchema outputSchema = + TableUtils.getOutputSchema(inputs[0].getResolvedSchema(), resultCols, resultTypes); + + DataStream<Row> output = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(input), + Collections.singletonMap(broadcastKey, model), + inputList -> { + DataStream inoutData = inputList.get(0); + return inoutData.transform( + "mapFunc", + TableUtils.getRowTypeInfo(outputSchema), + new PredictOperator(broadcastKey, getK(), getFeaturesCol())); + }); + + return new Table[] {tEnv.fromDataStream(output)}; + } + + /** @return Parameters for algorithm. */ + @Override + public Map<Param<?>, Object> getParamMap() { + return this.params; + } + + /** + * Saves model data. + * + * @param path Path to save. + */ + @Override + public void save(String path) throws IOException { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) modelData[0]).getTableEnvironment(); + + String dataPath = ReadWriteUtils.getDataPath(path); + FileSink<Row> sink = + FileSink.forRowFormat(new Path(dataPath), new KnnModelData.ModelDataEncoder()) + .withRollingPolicy(OnCheckpointRollingPolicy.build()) + .withBucketAssigner(new BasePathBucketAssigner<>()) + .build(); + tEnv.toDataStream(modelData[0]).sinkTo(sink); + ReadWriteUtils.saveMetadata(this, path); + } + + /** + * Loads model data from path. + * + * @param env Stream execution environment. + * @param path Model path. + * @return Knn model. + */ + public static KnnModel load(StreamExecutionEnvironment env, String path) throws IOException { + StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); + KnnModel retModel = ReadWriteUtils.loadStageParam(path); + + Source<Row, ?, ?> source = + FileSource.forRecordStreamFormat( + new KnnModelData.ModelDataStreamFormat(), + ReadWriteUtils.getDataPaths(path)) + .build(); + DataStream<Row> modelDataStream = + env.fromSource(source, WatermarkStrategy.noWatermarks(), "data"); + retModel.modelData = + new Table[] {tEnv.fromDataStream(modelDataStream, KnnModelData.getModelSchema())}; + return retModel; + } + + /** This operator loads model data and predicts result. 
*/ + private static class PredictOperator + extends AbstractUdfStreamOperator<Row, AbstractRichFunction> + implements OneInputStreamOperator<Row, Row> { + + private boolean firstEle = true; + private final String featureCol; + private transient KnnModelData modelData; + private final Integer topN; + private final String broadcastKey; + + public PredictOperator(String broadcastKey, int k, String featureCol) { + super(new AbstractRichFunction() {}); + this.topN = k; + this.broadcastKey = broadcastKey; + this.featureCol = featureCol; + } + + @Override + public void processElement(StreamRecord<Row> streamRecord) { + Row value = streamRecord.getValue(); + output.collect(new StreamRecord<>(map(value))); + } + + private Row map(Row row) { + if (firstEle) { + loadModel(userFunction.getRuntimeContext().getBroadcastVariable(broadcastKey)); + firstEle = false; + } + DenseVector vector = (DenseVector) row.getField(featureCol); + Tuple2<List<Integer>, List<Double>> t2 = findNeighbor(vector, topN, modelData); + Row ret = new Row(row.getArity() + 1); + for (int i = 0; i < row.getArity(); ++i) { + ret.setField(i, row.getField(i)); + } + + ret.setField(row.getArity(), getResult(t2)); + return ret; + } + + /** + * Finds the nearest topN neighbors from whole nodes. + * + * @param input Input vector. + * @param topN Top N. + * @return Neighbors. + */ + private Tuple2<List<Integer>, List<Double>> findNeighbor( + DenseVector input, Integer topN, KnnModelData modelData) { + PriorityQueue<Tuple2<Double, Integer>> priorityQueue = + new PriorityQueue<>(modelData.getQueueComparator()); + search(input, topN, priorityQueue, modelData); + List<Integer> items = new ArrayList<>(); + List<Double> metrics = new ArrayList<>(); + while (!priorityQueue.isEmpty()) { + Tuple2<Double, Integer> result = priorityQueue.poll(); + items.add(result.f1); + metrics.add(result.f0); + } + Collections.reverse(items); + Collections.reverse(metrics); + priorityQueue.clear(); + return Tuple2.of(items, metrics); + } + + /** + * @param input Input vector. + * @param topN Top N. + * @param priorityQueue Priority queue. + */ + private void search( + DenseVector input, + Integer topN, + PriorityQueue<Tuple2<Double, Integer>> priorityQueue, + KnnModelData modelData) { + Tuple2<DenseVector, Double> sample = computeNorm(input); + Tuple2<Double, Integer> head = null; + for (int i = 0; i < modelData.getLength(); i++) { + List<Tuple2<Double, Integer>> values = computeDistance(sample, i); + for (Tuple2<Double, Integer> currentValue : values) { + head = updateQueue(priorityQueue, topN, currentValue, head); + } + } + } + + /** + * Updates queue. + * + * @param pq Queue. + * @param topN Top N. + * @param newValue New value. + * @param head Head value. + * @return Head value. + */ + private <T> Tuple2<Double, T> updateQueue( + PriorityQueue<Tuple2<Double, T>> pq, + int topN, + Tuple2<Double, T> newValue, + Tuple2<Double, T> head) { + if (pq.size() < topN) { + pq.add(Tuple2.of(newValue.f0, newValue.f1)); + head = pq.peek(); + } else { + if (pq.comparator().compare(head, newValue) < 0) { + Tuple2<Double, T> peek = pq.poll(); + assert peek != null; + peek.f0 = newValue.f0; + peek.f1 = newValue.f1; + pq.add(peek); + head = pq.peek(); + } + } + return head; + } + + /** + * Computes distance between sample and dictionary vectors. + * + * @param input Samples with l2 norm. + * @param index Dictionary vectors index. + * @return Distances. 
+ */ + private List<Tuple2<Double, Integer>> computeDistance( + Tuple2<DenseVector, Double> input, Integer index) { + Tuple3<DenseMatrix, DenseVector, int[]> data = modelData.getDictData().get(index); + + DenseMatrix vectors = data.f0; + DenseMatrix distanceMatrix = new DenseMatrix(Objects.requireNonNull(vectors).numCols, 1); 
Review comment: Since one of the dimensions of `distanceMatrix` is 1, would it be more intuitive to make it a `DenseVector` instead of a `DenseMatrix`? 
########## File path: flink-ml-core/src/main/java/org/apache/flink/ml/linalg/DenseMatrix.java ########## @@ -0,0 +1,71 @@ +package org.apache.flink.ml.linalg; + +import org.apache.flink.api.common.typeinfo.TypeInfo; +import org.apache.flink.ml.linalg.typeinfo.DenseMatrixTypeInfoFactory; + +import org.apache.flink.shaded.curator4.com.google.common.base.Preconditions; + +/** + * Column-major dense matrix. The entry values are stored in a single array of doubles with columns + * listed in sequence. + */ +@TypeInfo(DenseMatrixTypeInfoFactory.class) +public class DenseMatrix implements Matrix { + + /** Row dimension. */ + public final int numRows; + + /** Column dimension. */ + public final int numCols; + + /** + * Array for internal storage of elements. + * + * <p>The matrix data is stored in column major format internally. + */ + public final double[] values; + + /** + * Constructs an m-by-n matrix of zeros. + * + * @param numRows Number of rows. + * @param numCols Number of columns. + */ + public DenseMatrix(int numRows, int numCols) { + this(numRows, numCols, new double[numRows * numCols]); + } + + /** + * Constructs a matrix from a 1-D array. The data in the array should organize in column major. + * + * @param numRows Number of rows. + * @param numCols Number of cols. + * @param values One-dimensional array of doubles. + */ + public DenseMatrix(int numRows, int numCols, double[] values) { + Preconditions.checkArgument(values.length == numRows * numCols); + this.numRows = numRows; + this.numCols = numCols; + this.values = values; + } + + @Override + public int numRows() { + return numRows; + } + + @Override + public int numCols() { + return numCols; + } + + @Override + public double get(int i, int j) { + return values[numRows * j + i]; 
Review comment: Would it be useful to verify that `i` and `j` are in the expected range? Note that unlike `DenseVector`, here `numRows * j + i` could be in range even if `i` is out of range. It would be useful to throw an exception in this case. 
-- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
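
Regarding the `DenseVector` suggestion for `computeDistance`, a sketch keeping the `BLAS.gemv` call and the norm arithmetic from the PR (the `distances` variable name is an assumption):

```java
DenseVector distances = new DenseVector(new double[vectors.numCols]);
// -2 * vectors^T * input.f0, written into distances.
BLAS.gemv(-2.0, vectors, true, input.f0, 0.0, distances);
for (int i = 0; i < distances.values.length; i++) {
    distances.values[i] = Math.sqrt(Math.abs(distances.values[i] + input.f1 + normL2Square[i]));
}
```

And a bounds-checked `DenseMatrix.get(int, int)` along the lines of the last comment could look like this (a sketch using the `Preconditions` class already imported in the file):

```java
@Override
public double get(int i, int j) {
    // Check each index separately: numRows * j + i can still be a valid offset
    // even when i alone is out of range.
    Preconditions.checkArgument(i >= 0 && i < numRows, "Row index out of range: %s", i);
    Preconditions.checkArgument(j >= 0 && j < numCols, "Column index out of range: %s", j);
    return values[numRows * j + i];
}
```
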
