lindong28 commented on a change in pull request #24: URL: https://github.com/apache/flink-ml/pull/24#discussion_r767490529
########## File path: flink-ml-core/src/main/java/org/apache/flink/ml/linalg/DenseMatrix.java ########## @@ -0,0 +1,69 @@ +package org.apache.flink.ml.linalg; + +import org.apache.flink.api.common.typeinfo.TypeInfo; +import org.apache.flink.ml.linalg.typeinfo.DenseMatrixTypeInfoFactory; + +import org.apache.flink.shaded.curator4.com.google.common.base.Preconditions; Review comment: Can we replace this with `org.apache.flink.util.Preconditions`? ########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java ########## @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.classification.knn; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.ml.util.StageTestUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.junit.Assert.assertEquals; + +/** Tests {@link Knn} and {@link KnnModel}. */ +public class KnnTest { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + private Table trainData; + private Table validData; Review comment: nit: it is not very clear what `validData` refers to. Would `predictData` be a bit more intuitive here? ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java ########## @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.knn; + +import org.apache.flink.api.common.serialization.Encoder; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.connector.file.src.reader.SimpleStreamFormat; +import org.apache.flink.core.fs.FSDataInputStream; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.ml.linalg.DenseMatrix; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.typeinfo.DenseMatrixSerializer; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; + +import java.io.EOFException; +import java.io.IOException; +import java.io.OutputStream; + +/** + * Model data of {@link KnnModel}. 
+ * + * <p>This class also provides methods to convert model data from Table to Datastream, and classes + * to save/load model data. + */ +public class KnnModelData { Review comment: @gaoyunhaii is it OK to use `KnnModelData` as the DataStream element type when `KnnModelData` is not serializable? If we use `KnnModelData` as the DataStream element type, would Flink automatically use `DenseMatrixSerializer` and `DenseVectorSerializer` to serialize/deserialize its member variables? ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java ########## @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.flink.ml.classification.knn; + +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.classification.knn.KnnModelData.ModelDataDecoder; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseMatrix; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.PriorityQueue; + +/** A Model which classifies data using the model data computed by {@link Knn}. */ +public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> { + protected Map<Param<?>, Object> params = new HashMap<>(); + private Table modelDataTable; + + public KnnModel() { + ParamUtils.initializeMapWithDefaultValues(params, this); + } + + @Override + public KnnModel setModelData(Table... 
modelData) { + this.modelDataTable = modelData[0]; + return this; + } + + @Override + public Table[] getModelData() { + return new Table[] {modelDataTable}; + } + + @Override + @SuppressWarnings("unchecked") + public Table[] transform(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<Row> data = tEnv.toDataStream(inputs[0]); + DataStream<KnnModelData> knnModel = KnnModelData.getModelDataStream(modelDataTable); + final String broadcastModelKey = "broadcastModelKey"; + + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll( + inputTypeInfo.getFieldTypes(), BasicTypeInfo.DOUBLE_TYPE_INFO), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol())); + DataStream<Row> output = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(data), + Collections.singletonMap(broadcastModelKey, knnModel), + inputList -> { + DataStream input = inputList.get(0); + return input.map( + new PredictLabelFunction( + broadcastModelKey, getK(), getFeaturesCol()), + outputTypeInfo); + }); + return new Table[] {tEnv.fromDataStream(output)}; + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return this.params; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + ReadWriteUtils.saveModelData( + KnnModelData.getModelDataStream(modelDataTable), + path, + new KnnModelData.ModelDataEncoder(), + 0); + } + + /** + * Loads model data from path. + * + * @param env Stream execution environment. + * @param path Model path. + * @return Knn model. 
+ */ + public static KnnModel load(StreamExecutionEnvironment env, String path) throws IOException { + KnnModel model = ReadWriteUtils.loadStageParam(path); + Table modelDataTable = ReadWriteUtils.loadModelData(env, path, new ModelDataDecoder(), 0); + return model.setModelData(modelDataTable); + } + + /** This operator loads model data and predicts result. */ + private static class PredictLabelFunction extends RichMapFunction<Row, Row> { + private final String featureCol; + private KnnModelData knnModelData; + private final int k; + private final String broadcastKey; + + public PredictLabelFunction(String broadcastKey, int k, String featureCol) { + this.k = k; + this.broadcastKey = broadcastKey; + this.featureCol = featureCol; + } + + @Override + public Row map(Row row) { + if (knnModelData == null) { + knnModelData = + (KnnModelData) + getRuntimeContext().getBroadcastVariable(broadcastKey).get(0); + } + DenseVector vector = (DenseVector) row.getField(featureCol); + if (vector == null) { Review comment: Hmm... why do we allow users to provide an input row whose `featureCol` value is null? Would it be better to throw an exception here? Also, could we rename this variable to `feature` to be consistent with `featureCol`? ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java ########## @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.knn; + +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.classification.knn.KnnModelData.ModelDataDecoder; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseMatrix; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.PriorityQueue; + +/** A Model which classifies data using the model data computed by {@link Knn}. 
*/ +public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> { + protected Map<Param<?>, Object> params = new HashMap<>(); + private Table modelDataTable; + + public KnnModel() { + ParamUtils.initializeMapWithDefaultValues(params, this); + } + + @Override + public KnnModel setModelData(Table... modelData) { + this.modelDataTable = modelData[0]; + return this; + } + + @Override + public Table[] getModelData() { + return new Table[] {modelDataTable}; + } + + @Override + @SuppressWarnings("unchecked") + public Table[] transform(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<Row> data = tEnv.toDataStream(inputs[0]); + DataStream<KnnModelData> knnModel = KnnModelData.getModelDataStream(modelDataTable); + final String broadcastModelKey = "broadcastModelKey"; + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll( + inputTypeInfo.getFieldTypes(), BasicTypeInfo.DOUBLE_TYPE_INFO), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol())); + DataStream<Row> output = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(data), + Collections.singletonMap(broadcastModelKey, knnModel), + inputList -> { + DataStream input = inputList.get(0); + return input.map( + new PredictLabelFunction( + broadcastModelKey, getK(), getFeaturesCol()), + outputTypeInfo); + }); + return new Table[] {tEnv.fromDataStream(output)}; + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return this.params; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + ReadWriteUtils.saveModelData( + KnnModelData.getModelDataStream(modelDataTable), + path, + new KnnModelData.ModelDataEncoder(), + 0); + } + + /** + * Loads model data from path. 
+ * + * @param env Stream execution environment. + * @param path Model path. + * @return Knn model. + */ + public static KnnModel load(StreamExecutionEnvironment env, String path) throws IOException { + KnnModel model = ReadWriteUtils.loadStageParam(path); + Table modelDataTable = ReadWriteUtils.loadModelData(env, path, new ModelDataDecoder(), 0); + return model.setModelData(modelDataTable); + } + + /** This operator loads model data and predicts result. */ + private static class PredictLabelFunction extends RichMapFunction<Row, Row> { + private final String featureCol; + private KnnModelData knnModelData; + private final int k; + private final String broadcastKey; + + public PredictLabelFunction(String broadcastKey, int k, String featureCol) { + this.k = k; + this.broadcastKey = broadcastKey; + this.featureCol = featureCol; + } + + @Override + public Row map(Row row) { + if (knnModelData == null) { + knnModelData = + (KnnModelData) + getRuntimeContext().getBroadcastVariable(broadcastKey).get(0); + } + DenseVector vector = (DenseVector) row.getField(featureCol); + if (vector == null) { + return Row.join(row, new Row(1)); + } + Tuple2<List<Double>, List<Double>> tuple2 = findNeighbor(vector); Review comment: It seems that only the first element of this tuple2 is used. Could we simplify `findNeighbor(...)` to return just a `List<Double>`? ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java ########## @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.knn; + +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.classification.knn.KnnModelData.ModelDataDecoder; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseMatrix; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.PriorityQueue; + +/** A Model which classifies data using the model data computed by {@link Knn}. 
*/ +public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> { + protected Map<Param<?>, Object> params = new HashMap<>(); + private Table modelDataTable; + + public KnnModel() { + ParamUtils.initializeMapWithDefaultValues(params, this); + } + + @Override + public KnnModel setModelData(Table... modelData) { + this.modelDataTable = modelData[0]; + return this; + } + + @Override + public Table[] getModelData() { + return new Table[] {modelDataTable}; + } + + @Override + @SuppressWarnings("unchecked") + public Table[] transform(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<Row> data = tEnv.toDataStream(inputs[0]); + DataStream<KnnModelData> knnModel = KnnModelData.getModelDataStream(modelDataTable); + final String broadcastModelKey = "broadcastModelKey"; + + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll( + inputTypeInfo.getFieldTypes(), BasicTypeInfo.DOUBLE_TYPE_INFO), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol())); + DataStream<Row> output = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(data), + Collections.singletonMap(broadcastModelKey, knnModel), + inputList -> { + DataStream input = inputList.get(0); + return input.map( + new PredictLabelFunction( + broadcastModelKey, getK(), getFeaturesCol()), + outputTypeInfo); + }); + return new Table[] {tEnv.fromDataStream(output)}; + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return this.params; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + ReadWriteUtils.saveModelData( + KnnModelData.getModelDataStream(modelDataTable), + path, + new KnnModelData.ModelDataEncoder(), + 0); + } + + /** + * Loads model data from path. 
+ * + * @param env Stream execution environment. + * @param path Model path. + * @return Knn model. + */ + public static KnnModel load(StreamExecutionEnvironment env, String path) throws IOException { + KnnModel model = ReadWriteUtils.loadStageParam(path); + Table modelDataTable = ReadWriteUtils.loadModelData(env, path, new ModelDataDecoder(), 0); + return model.setModelData(modelDataTable); + } + + /** This operator loads model data and predicts result. */ + private static class PredictLabelFunction extends RichMapFunction<Row, Row> { + private final String featureCol; + private KnnModelData knnModelData; + private final int k; + private final String broadcastKey; + + public PredictLabelFunction(String broadcastKey, int k, String featureCol) { + this.k = k; + this.broadcastKey = broadcastKey; + this.featureCol = featureCol; + } + + @Override + public Row map(Row row) { + if (knnModelData == null) { + knnModelData = + (KnnModelData) + getRuntimeContext().getBroadcastVariable(broadcastKey).get(0); + } + DenseVector vector = (DenseVector) row.getField(featureCol); + if (vector == null) { + return Row.join(row, new Row(1)); + } + Tuple2<List<Double>, List<Double>> tuple2 = findNeighbor(vector); + double percent = 1.0 / tuple2.f0.size(); + Map<Double, Double> detail = new HashMap<>(0); + for (Double obj : tuple2.f0) { Review comment: nits: could we rename `Double obj` to `Double label` to make the code a bit more self-explanatory? ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java ########## @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.knn; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.linalg.DenseMatrix; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * An Estimator which implements the KNN algorithm. + * + * <p>See: https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm. 
+ */ +public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> { + + protected Map<Param<?>, Object> params = new HashMap<>(); + + public Knn() { + ParamUtils.initializeMapWithDefaultValues(params, this); + } + + @Override + public KnnModel fit(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + /* Tuple2 : <sampleVector, label> */ + DataStream<Tuple2<DenseVector, Double>> inputData = + tEnv.toDataStream(inputs[0]) + .map( + new MapFunction<Row, Tuple2<DenseVector, Double>>() { + @Override + public Tuple2<DenseVector, Double> map(Row value) { + Double label = (Double) value.getField(getLabelCol()); + DenseVector feature = + (DenseVector) value.getField(getFeaturesCol()); + return Tuple2.of(feature, label); + } + }); + DataStream<KnnModelData> distributedModelData = prepareModelData(inputData); + DataStream<KnnModelData> modelData = mergeModelData(distributedModelData); + KnnModel model = new KnnModel().setModelData(tEnv.fromDataStream(modelData)); + ReadWriteUtils.updateExistingParams(model, getParamMap()); + return model; + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return this.params; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static Knn load(StreamExecutionEnvironment env, String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + /** + * Prepares distributed knn model data. Constructs the sample matrix and computes norm of + * vectors. + * + * @param inputData Input vector data with label. + * @return Distributed knn model. 
+ */ + private static DataStream<KnnModelData> prepareModelData( + DataStream<Tuple2<DenseVector, Double>> inputData) { + return DataStreamUtils.mapPartition( + inputData, + new RichMapPartitionFunction<Tuple2<DenseVector, Double>, KnnModelData>() { + @Override + public void mapPartition( + Iterable<Tuple2<DenseVector, Double>> values, + Collector<KnnModelData> out) { + List<Tuple2<DenseVector, Double>> dataPoints = new ArrayList<>(0); + for (Tuple2<DenseVector, Double> tuple2 : values) { + dataPoints.add(tuple2); + } + int featureDim = dataPoints.get(0).f0.size(); + DenseMatrix packedFeatures = new DenseMatrix(featureDim, dataPoints.size()); + DenseVector labels = new DenseVector(dataPoints.size()); + for (int i = 0; i < dataPoints.size(); ++i) { + Tuple2<DenseVector, Double> tuple2 = dataPoints.get(i); + labels.values[i] = tuple2.f1; + double[] vectorData = tuple2.f0.values; + double[] matrixData = packedFeatures.values; + System.arraycopy(vectorData, 0, matrixData, i * featureDim, featureDim); + } + DenseVector featureNorms = computeNorm(packedFeatures); + if (dataPoints.size() > 0) { + out.collect(new KnnModelData(packedFeatures, featureNorms, labels)); + } + } + }, + TypeInformation.of(KnnModelData.class)); + } + + /** + * Merges knn model data. + * + * @param distributedModelData Distributed knn model data. + * @return Knn model. + */ + private static DataStream<KnnModelData> mergeModelData( + DataStream<KnnModelData> distributedModelData) { + distributedModelData.getTransformation().setParallelism(1); + return DataStreamUtils.mapPartition( + distributedModelData, + new RichMapPartitionFunction<KnnModelData, KnnModelData>() { + @Override + public void mapPartition( + Iterable<KnnModelData> values, Collector<KnnModelData> out) { + List<KnnModelData> bufferKnnModelData = new ArrayList<>(1); Review comment: nits: would it be slightly better to use `bufferedKnnModelData`? 
########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java ########## @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.knn; + +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.classification.knn.KnnModelData.ModelDataDecoder; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseMatrix; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import 
org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.PriorityQueue; + +/** A Model which classifies data using the model data computed by {@link Knn}. */ +public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> { + protected Map<Param<?>, Object> params = new HashMap<>(); + private Table modelDataTable; + + public KnnModel() { + ParamUtils.initializeMapWithDefaultValues(params, this); + } + + @Override + public KnnModel setModelData(Table... modelData) { + this.modelDataTable = modelData[0]; + return this; + } + + @Override + public Table[] getModelData() { + return new Table[] {modelDataTable}; + } + + @Override + @SuppressWarnings("unchecked") + public Table[] transform(Table... 
inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<Row> data = tEnv.toDataStream(inputs[0]); + DataStream<KnnModelData> knnModel = KnnModelData.getModelDataStream(modelDataTable); + final String broadcastModelKey = "broadcastModelKey"; + + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll( + inputTypeInfo.getFieldTypes(), BasicTypeInfo.DOUBLE_TYPE_INFO), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol())); + DataStream<Row> output = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(data), + Collections.singletonMap(broadcastModelKey, knnModel), + inputList -> { + DataStream input = inputList.get(0); + return input.map( + new PredictLabelFunction( + broadcastModelKey, getK(), getFeaturesCol()), + outputTypeInfo); + }); + return new Table[] {tEnv.fromDataStream(output)}; + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return this.params; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + ReadWriteUtils.saveModelData( + KnnModelData.getModelDataStream(modelDataTable), + path, + new KnnModelData.ModelDataEncoder(), + 0); + } + + /** + * Loads model data from path. + * + * @param env Stream execution environment. + * @param path Model path. + * @return Knn model. + */ + public static KnnModel load(StreamExecutionEnvironment env, String path) throws IOException { + KnnModel model = ReadWriteUtils.loadStageParam(path); + Table modelDataTable = ReadWriteUtils.loadModelData(env, path, new ModelDataDecoder(), 0); + return model.setModelData(modelDataTable); + } + + /** This operator loads model data and predicts result. 
*/ + private static class PredictLabelFunction extends RichMapFunction<Row, Row> { + private final String featureCol; + private KnnModelData knnModelData; + private final int k; + private final String broadcastKey; + + public PredictLabelFunction(String broadcastKey, int k, String featureCol) { + this.k = k; + this.broadcastKey = broadcastKey; + this.featureCol = featureCol; + } + + @Override + public Row map(Row row) { + if (knnModelData == null) { + knnModelData = + (KnnModelData) + getRuntimeContext().getBroadcastVariable(broadcastKey).get(0); + } + DenseVector vector = (DenseVector) row.getField(featureCol); + if (vector == null) { + return Row.join(row, new Row(1)); + } + Tuple2<List<Double>, List<Double>> tuple2 = findNeighbor(vector); + double percent = 1.0 / tuple2.f0.size(); + Map<Double, Double> detail = new HashMap<>(0); + for (Double obj : tuple2.f0) { + detail.merge(obj, percent, Double::sum); + } + double max = 0.0; + double prediction = 0.0; + for (Map.Entry<Double, Double> entry : detail.entrySet()) { + if (entry.getValue() > max) { + max = entry.getValue(); + prediction = entry.getKey(); + } + } + return Row.join(row, Row.of(prediction)); + } + + /** Finds the nearest k neighbors from whole vectors in matrix. 
*/ + private Tuple2<List<Double>, List<Double>> findNeighbor(DenseVector input) { + PriorityQueue<Tuple2<Double, Double>> priorityQueue = + new PriorityQueue<>(Comparator.comparingDouble(o -> -o.f0)); + double d = 0.0; + for (int i = 0; i < input.size(); ++i) { + d += input.values[i] * input.values[i]; + } + Tuple2<DenseVector, Double> sample = Tuple2.of(input, d); + DenseMatrix packedFeatures = knnModelData.packedFeatures; + double[] labelValues = knnModelData.labels.values; + double[] normValues = knnModelData.featureNorms.values; + DenseVector distanceVector = new DenseVector(packedFeatures.numCols()); + BLAS.gemv(-2.0, packedFeatures, true, sample.f0, 0.0, distanceVector); + for (int i = 0; i < distanceVector.values.length; i++) { + distanceVector.values[i] = + Math.sqrt(Math.abs(distanceVector.values[i] + sample.f1 + normValues[i])); + } + List<Tuple2<Double, Double>> distances = new ArrayList<>(labelValues.length); + for (int i = 0; i < labelValues.length; i++) { + Tuple2<Double, Double> tuple2 = Tuple2.of(distanceVector.values[i], labelValues[i]); + distances.add(tuple2); + } + Tuple2<Double, Double> head = null; + for (Tuple2<Double, Double> currentValue : distances) { Review comment: nits: could we rename `currentValue` to be something more meaningful and consistent with `distance`, e.g. `distanceAndLabel`? ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java ########## @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.knn; + +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.classification.knn.KnnModelData.ModelDataDecoder; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseMatrix; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.PriorityQueue; + +/** A Model which classifies data using the model data computed by {@link Knn}. 
*/ +public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> { + protected Map<Param<?>, Object> params = new HashMap<>(); + private Table modelDataTable; + + public KnnModel() { + ParamUtils.initializeMapWithDefaultValues(params, this); + } + + @Override + public KnnModel setModelData(Table... modelData) { + this.modelDataTable = modelData[0]; + return this; + } + + @Override + public Table[] getModelData() { + return new Table[] {modelDataTable}; + } + + @Override + @SuppressWarnings("unchecked") + public Table[] transform(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<Row> data = tEnv.toDataStream(inputs[0]); + DataStream<KnnModelData> knnModel = KnnModelData.getModelDataStream(modelDataTable); + final String broadcastModelKey = "broadcastModelKey"; + + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll( + inputTypeInfo.getFieldTypes(), BasicTypeInfo.DOUBLE_TYPE_INFO), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol())); + DataStream<Row> output = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(data), + Collections.singletonMap(broadcastModelKey, knnModel), + inputList -> { + DataStream input = inputList.get(0); + return input.map( + new PredictLabelFunction( + broadcastModelKey, getK(), getFeaturesCol()), + outputTypeInfo); + }); + return new Table[] {tEnv.fromDataStream(output)}; + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return this.params; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + ReadWriteUtils.saveModelData( + KnnModelData.getModelDataStream(modelDataTable), + path, + new KnnModelData.ModelDataEncoder(), + 0); + } + + /** + * Loads model data from path. 
+ * + * @param env Stream execution environment. + * @param path Model path. + * @return Knn model. + */ + public static KnnModel load(StreamExecutionEnvironment env, String path) throws IOException { + KnnModel model = ReadWriteUtils.loadStageParam(path); + Table modelDataTable = ReadWriteUtils.loadModelData(env, path, new ModelDataDecoder(), 0); + return model.setModelData(modelDataTable); + } + + /** This operator loads model data and predicts result. */ + private static class PredictLabelFunction extends RichMapFunction<Row, Row> { + private final String featureCol; + private KnnModelData knnModelData; + private final int k; + private final String broadcastKey; + + public PredictLabelFunction(String broadcastKey, int k, String featureCol) { + this.k = k; + this.broadcastKey = broadcastKey; + this.featureCol = featureCol; + } + + @Override + public Row map(Row row) { + if (knnModelData == null) { + knnModelData = + (KnnModelData) + getRuntimeContext().getBroadcastVariable(broadcastKey).get(0); + } + DenseVector vector = (DenseVector) row.getField(featureCol); + if (vector == null) { + return Row.join(row, new Row(1)); + } + Tuple2<List<Double>, List<Double>> tuple2 = findNeighbor(vector); + double percent = 1.0 / tuple2.f0.size(); + Map<Double, Double> detail = new HashMap<>(0); + for (Double obj : tuple2.f0) { + detail.merge(obj, percent, Double::sum); + } + double max = 0.0; + double prediction = 0.0; + for (Map.Entry<Double, Double> entry : detail.entrySet()) { + if (entry.getValue() > max) { + max = entry.getValue(); + prediction = entry.getKey(); + } + } + return Row.join(row, Row.of(prediction)); + } + + /** Finds the nearest k neighbors from whole vectors in matrix. 
*/ + private Tuple2<List<Double>, List<Double>> findNeighbor(DenseVector input) { + PriorityQueue<Tuple2<Double, Double>> priorityQueue = + new PriorityQueue<>(Comparator.comparingDouble(o -> -o.f0)); + double d = 0.0; + for (int i = 0; i < input.size(); ++i) { + d += input.values[i] * input.values[i]; + } + Tuple2<DenseVector, Double> sample = Tuple2.of(input, d); + DenseMatrix packedFeatures = knnModelData.packedFeatures; + double[] labelValues = knnModelData.labels.values; + double[] normValues = knnModelData.featureNorms.values; + DenseVector distanceVector = new DenseVector(packedFeatures.numCols()); + BLAS.gemv(-2.0, packedFeatures, true, sample.f0, 0.0, distanceVector); + for (int i = 0; i < distanceVector.values.length; i++) { + distanceVector.values[i] = + Math.sqrt(Math.abs(distanceVector.values[i] + sample.f1 + normValues[i])); + } + List<Tuple2<Double, Double>> distances = new ArrayList<>(labelValues.length); + for (int i = 0; i < labelValues.length; i++) { + Tuple2<Double, Double> tuple2 = Tuple2.of(distanceVector.values[i], labelValues[i]); + distances.add(tuple2); + } + Tuple2<Double, Double> head = null; Review comment: nits: Instead of maintaining this variable, would it be simpler (with the same efficiency) to use `priorityQueue.peek()` whenever it is needed? ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java ########## @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.knn; + +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.classification.knn.KnnModelData.ModelDataDecoder; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseMatrix; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.PriorityQueue; + +/** A Model which classifies data using the model data computed by {@link Knn}. 
*/ +public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> { + protected Map<Param<?>, Object> params = new HashMap<>(); + private Table modelDataTable; + + public KnnModel() { + ParamUtils.initializeMapWithDefaultValues(params, this); + } + + @Override + public KnnModel setModelData(Table... modelData) { + this.modelDataTable = modelData[0]; + return this; + } + + @Override + public Table[] getModelData() { + return new Table[] {modelDataTable}; + } + + @Override + @SuppressWarnings("unchecked") + public Table[] transform(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<Row> data = tEnv.toDataStream(inputs[0]); + DataStream<KnnModelData> knnModel = KnnModelData.getModelDataStream(modelDataTable); + final String broadcastModelKey = "broadcastModelKey"; + + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll( + inputTypeInfo.getFieldTypes(), BasicTypeInfo.DOUBLE_TYPE_INFO), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol())); + DataStream<Row> output = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(data), + Collections.singletonMap(broadcastModelKey, knnModel), + inputList -> { + DataStream input = inputList.get(0); + return input.map( + new PredictLabelFunction( + broadcastModelKey, getK(), getFeaturesCol()), + outputTypeInfo); + }); + return new Table[] {tEnv.fromDataStream(output)}; + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return this.params; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + ReadWriteUtils.saveModelData( + KnnModelData.getModelDataStream(modelDataTable), + path, + new KnnModelData.ModelDataEncoder(), + 0); + } + + /** + * Loads model data from path. 
+ * + * @param env Stream execution environment. + * @param path Model path. + * @return Knn model. + */ + public static KnnModel load(StreamExecutionEnvironment env, String path) throws IOException { + KnnModel model = ReadWriteUtils.loadStageParam(path); + Table modelDataTable = ReadWriteUtils.loadModelData(env, path, new ModelDataDecoder(), 0); + return model.setModelData(modelDataTable); + } + + /** This operator loads model data and predicts result. */ + private static class PredictLabelFunction extends RichMapFunction<Row, Row> { + private final String featureCol; + private KnnModelData knnModelData; + private final int k; + private final String broadcastKey; + + public PredictLabelFunction(String broadcastKey, int k, String featureCol) { + this.k = k; + this.broadcastKey = broadcastKey; + this.featureCol = featureCol; + } + + @Override + public Row map(Row row) { + if (knnModelData == null) { + knnModelData = + (KnnModelData) + getRuntimeContext().getBroadcastVariable(broadcastKey).get(0); + } + DenseVector vector = (DenseVector) row.getField(featureCol); + if (vector == null) { + return Row.join(row, new Row(1)); + } + Tuple2<List<Double>, List<Double>> tuple2 = findNeighbor(vector); + double percent = 1.0 / tuple2.f0.size(); + Map<Double, Double> detail = new HashMap<>(0); + for (Double obj : tuple2.f0) { + detail.merge(obj, percent, Double::sum); + } + double max = 0.0; + double prediction = 0.0; + for (Map.Entry<Double, Double> entry : detail.entrySet()) { + if (entry.getValue() > max) { + max = entry.getValue(); + prediction = entry.getKey(); + } + } + return Row.join(row, Row.of(prediction)); + } + + /** Finds the nearest k neighbors from whole vectors in matrix. 
*/ + private Tuple2<List<Double>, List<Double>> findNeighbor(DenseVector input) { + PriorityQueue<Tuple2<Double, Double>> priorityQueue = + new PriorityQueue<>(Comparator.comparingDouble(o -> -o.f0)); + double d = 0.0; + for (int i = 0; i < input.size(); ++i) { + d += input.values[i] * input.values[i]; + } + Tuple2<DenseVector, Double> sample = Tuple2.of(input, d); + DenseMatrix packedFeatures = knnModelData.packedFeatures; + double[] labelValues = knnModelData.labels.values; + double[] normValues = knnModelData.featureNorms.values; + DenseVector distanceVector = new DenseVector(packedFeatures.numCols()); + BLAS.gemv(-2.0, packedFeatures, true, sample.f0, 0.0, distanceVector); + for (int i = 0; i < distanceVector.values.length; i++) { + distanceVector.values[i] = + Math.sqrt(Math.abs(distanceVector.values[i] + sample.f1 + normValues[i])); + } + List<Tuple2<Double, Double>> distances = new ArrayList<>(labelValues.length); Review comment: nits: could we give it a more meaningful name to improve readability, e.g. `distanceAndLabels`? ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java ########## @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.classification.knn; + +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.classification.knn.KnnModelData.ModelDataDecoder; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseMatrix; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.PriorityQueue; + +/** A Model which classifies data using the model data computed by {@link Knn}. */ +public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> { + protected Map<Param<?>, Object> params = new HashMap<>(); + private Table modelDataTable; + + public KnnModel() { + ParamUtils.initializeMapWithDefaultValues(params, this); + } + + @Override + public KnnModel setModelData(Table... 
modelData) { + this.modelDataTable = modelData[0]; + return this; + } + + @Override + public Table[] getModelData() { + return new Table[] {modelDataTable}; + } + + @Override + @SuppressWarnings("unchecked") + public Table[] transform(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<Row> data = tEnv.toDataStream(inputs[0]); + DataStream<KnnModelData> knnModel = KnnModelData.getModelDataStream(modelDataTable); + final String broadcastModelKey = "broadcastModelKey"; + + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll( + inputTypeInfo.getFieldTypes(), BasicTypeInfo.DOUBLE_TYPE_INFO), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol())); + DataStream<Row> output = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(data), + Collections.singletonMap(broadcastModelKey, knnModel), + inputList -> { + DataStream input = inputList.get(0); + return input.map( + new PredictLabelFunction( + broadcastModelKey, getK(), getFeaturesCol()), + outputTypeInfo); + }); + return new Table[] {tEnv.fromDataStream(output)}; + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return this.params; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + ReadWriteUtils.saveModelData( + KnnModelData.getModelDataStream(modelDataTable), + path, + new KnnModelData.ModelDataEncoder(), + 0); + } + + /** + * Loads model data from path. + * + * @param env Stream execution environment. + * @param path Model path. + * @return Knn model. 
+ */ + public static KnnModel load(StreamExecutionEnvironment env, String path) throws IOException { + KnnModel model = ReadWriteUtils.loadStageParam(path); + Table modelDataTable = ReadWriteUtils.loadModelData(env, path, new ModelDataDecoder(), 0); + return model.setModelData(modelDataTable); + } + + /** This operator loads model data and predicts result. */ + private static class PredictLabelFunction extends RichMapFunction<Row, Row> { + private final String featureCol; + private KnnModelData knnModelData; + private final int k; + private final String broadcastKey; + + public PredictLabelFunction(String broadcastKey, int k, String featureCol) { + this.k = k; + this.broadcastKey = broadcastKey; + this.featureCol = featureCol; + } + + @Override + public Row map(Row row) { + if (knnModelData == null) { + knnModelData = + (KnnModelData) + getRuntimeContext().getBroadcastVariable(broadcastKey).get(0); + } + DenseVector vector = (DenseVector) row.getField(featureCol); + if (vector == null) { + return Row.join(row, new Row(1)); + } + Tuple2<List<Double>, List<Double>> tuple2 = findNeighbor(vector); + double percent = 1.0 / tuple2.f0.size(); + Map<Double, Double> detail = new HashMap<>(0); + for (Double obj : tuple2.f0) { + detail.merge(obj, percent, Double::sum); + } + double max = 0.0; + double prediction = 0.0; + for (Map.Entry<Double, Double> entry : detail.entrySet()) { + if (entry.getValue() > max) { + max = entry.getValue(); + prediction = entry.getKey(); + } + } + return Row.join(row, Row.of(prediction)); + } + + /** Finds the nearest k neighbors from whole vectors in matrix. 
*/ + private Tuple2<List<Double>, List<Double>> findNeighbor(DenseVector input) { + PriorityQueue<Tuple2<Double, Double>> priorityQueue = + new PriorityQueue<>(Comparator.comparingDouble(o -> -o.f0)); + double d = 0.0; + for (int i = 0; i < input.size(); ++i) { + d += input.values[i] * input.values[i]; + } + Tuple2<DenseVector, Double> sample = Tuple2.of(input, d); + DenseMatrix packedFeatures = knnModelData.packedFeatures; + double[] labelValues = knnModelData.labels.values; + double[] normValues = knnModelData.featureNorms.values; + DenseVector distanceVector = new DenseVector(packedFeatures.numCols()); + BLAS.gemv(-2.0, packedFeatures, true, sample.f0, 0.0, distanceVector); + for (int i = 0; i < distanceVector.values.length; i++) { + distanceVector.values[i] = + Math.sqrt(Math.abs(distanceVector.values[i] + sample.f1 + normValues[i])); + } + List<Tuple2<Double, Double>> distances = new ArrayList<>(labelValues.length); + for (int i = 0; i < labelValues.length; i++) { + Tuple2<Double, Double> tuple2 = Tuple2.of(distanceVector.values[i], labelValues[i]); + distances.add(tuple2); + } + Tuple2<Double, Double> head = null; + for (Tuple2<Double, Double> currentValue : distances) { + if (priorityQueue.size() < k) { + priorityQueue.add(Tuple2.of(currentValue.f0, currentValue.f1)); + head = priorityQueue.peek(); + } else { + if (priorityQueue.comparator().compare(head, currentValue) < 0) { + Tuple2<Double, Double> peek = priorityQueue.poll(); + if (peek != null) { + peek.f0 = currentValue.f0; + peek.f1 = currentValue.f1; + priorityQueue.add(peek); + head = priorityQueue.peek(); + } + } + } + } + List<Double> items = new ArrayList<>(); + List<Double> metrics = new ArrayList<>(); Review comment: nits: could we make these variable names more consistent with the above code and also more meaningful? Maybe something like `labels` and `distances`. 
########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java ########## @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.knn; + +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.classification.knn.KnnModelData.ModelDataDecoder; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseMatrix; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import 
org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.PriorityQueue; + +/** A Model which classifies data using the model data computed by {@link Knn}. */ +public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> { + protected Map<Param<?>, Object> params = new HashMap<>(); + private Table modelDataTable; + + public KnnModel() { + ParamUtils.initializeMapWithDefaultValues(params, this); + } + + @Override + public KnnModel setModelData(Table... modelData) { + this.modelDataTable = modelData[0]; + return this; + } + + @Override + public Table[] getModelData() { + return new Table[] {modelDataTable}; + } + + @Override + @SuppressWarnings("unchecked") + public Table[] transform(Table... 
inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<Row> data = tEnv.toDataStream(inputs[0]); + DataStream<KnnModelData> knnModel = KnnModelData.getModelDataStream(modelDataTable); + final String broadcastModelKey = "broadcastModelKey"; + + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll( + inputTypeInfo.getFieldTypes(), BasicTypeInfo.DOUBLE_TYPE_INFO), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol())); + DataStream<Row> output = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(data), + Collections.singletonMap(broadcastModelKey, knnModel), + inputList -> { + DataStream input = inputList.get(0); + return input.map( + new PredictLabelFunction( + broadcastModelKey, getK(), getFeaturesCol()), + outputTypeInfo); + }); + return new Table[] {tEnv.fromDataStream(output)}; + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return this.params; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + ReadWriteUtils.saveModelData( + KnnModelData.getModelDataStream(modelDataTable), + path, + new KnnModelData.ModelDataEncoder(), + 0); + } + + /** + * Loads model data from path. + * + * @param env Stream execution environment. + * @param path Model path. + * @return Knn model. + */ + public static KnnModel load(StreamExecutionEnvironment env, String path) throws IOException { + KnnModel model = ReadWriteUtils.loadStageParam(path); + Table modelDataTable = ReadWriteUtils.loadModelData(env, path, new ModelDataDecoder(), 0); + return model.setModelData(modelDataTable); + } + + /** This operator loads model data and predicts result. 
*/ + private static class PredictLabelFunction extends RichMapFunction<Row, Row> { + private final String featureCol; + private KnnModelData knnModelData; + private final int k; + private final String broadcastKey; + + public PredictLabelFunction(String broadcastKey, int k, String featureCol) { + this.k = k; + this.broadcastKey = broadcastKey; + this.featureCol = featureCol; + } + + @Override + public Row map(Row row) { + if (knnModelData == null) { + knnModelData = + (KnnModelData) + getRuntimeContext().getBroadcastVariable(broadcastKey).get(0); + } + DenseVector vector = (DenseVector) row.getField(featureCol); + if (vector == null) { + return Row.join(row, new Row(1)); + } + Tuple2<List<Double>, List<Double>> tuple2 = findNeighbor(vector); + double percent = 1.0 / tuple2.f0.size(); + Map<Double, Double> detail = new HashMap<>(0); + for (Double obj : tuple2.f0) { + detail.merge(obj, percent, Double::sum); + } + double max = 0.0; + double prediction = 0.0; + for (Map.Entry<Double, Double> entry : detail.entrySet()) { + if (entry.getValue() > max) { + max = entry.getValue(); + prediction = entry.getKey(); + } + } + return Row.join(row, Row.of(prediction)); + } + + /** Finds the nearest k neighbors from whole vectors in matrix. */ + private Tuple2<List<Double>, List<Double>> findNeighbor(DenseVector input) { + PriorityQueue<Tuple2<Double, Double>> priorityQueue = Review comment: nits: would it be slightly better to rename this as `nearestKNeighbors`? ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java ########## @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.knn; + +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.classification.knn.KnnModelData.ModelDataDecoder; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseMatrix; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.PriorityQueue; + +/** A Model which classifies data using the model data computed by {@link Knn}. 
*/ +public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> { + protected Map<Param<?>, Object> params = new HashMap<>(); + private Table modelDataTable; + + public KnnModel() { + ParamUtils.initializeMapWithDefaultValues(params, this); + } + + @Override + public KnnModel setModelData(Table... modelData) { + this.modelDataTable = modelData[0]; + return this; + } + + @Override + public Table[] getModelData() { + return new Table[] {modelDataTable}; + } + + @Override + @SuppressWarnings("unchecked") + public Table[] transform(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<Row> data = tEnv.toDataStream(inputs[0]); + DataStream<KnnModelData> knnModel = KnnModelData.getModelDataStream(modelDataTable); + final String broadcastModelKey = "broadcastModelKey"; + + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll( + inputTypeInfo.getFieldTypes(), BasicTypeInfo.DOUBLE_TYPE_INFO), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol())); + DataStream<Row> output = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(data), + Collections.singletonMap(broadcastModelKey, knnModel), + inputList -> { + DataStream input = inputList.get(0); + return input.map( + new PredictLabelFunction( + broadcastModelKey, getK(), getFeaturesCol()), + outputTypeInfo); + }); + return new Table[] {tEnv.fromDataStream(output)}; + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return this.params; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + ReadWriteUtils.saveModelData( + KnnModelData.getModelDataStream(modelDataTable), + path, + new KnnModelData.ModelDataEncoder(), + 0); + } + + /** + * Loads model data from path. 
+ * + * @param env Stream execution environment. + * @param path Model path. + * @return Knn model. + */ + public static KnnModel load(StreamExecutionEnvironment env, String path) throws IOException { + KnnModel model = ReadWriteUtils.loadStageParam(path); + Table modelDataTable = ReadWriteUtils.loadModelData(env, path, new ModelDataDecoder(), 0); + return model.setModelData(modelDataTable); + } + + /** This operator loads model data and predicts result. */ + private static class PredictLabelFunction extends RichMapFunction<Row, Row> { + private final String featureCol; + private KnnModelData knnModelData; + private final int k; + private final String broadcastKey; + + public PredictLabelFunction(String broadcastKey, int k, String featureCol) { + this.k = k; + this.broadcastKey = broadcastKey; + this.featureCol = featureCol; + } + + @Override + public Row map(Row row) { + if (knnModelData == null) { + knnModelData = + (KnnModelData) + getRuntimeContext().getBroadcastVariable(broadcastKey).get(0); + } + DenseVector vector = (DenseVector) row.getField(featureCol); + if (vector == null) { + return Row.join(row, new Row(1)); + } + Tuple2<List<Double>, List<Double>> tuple2 = findNeighbor(vector); + double percent = 1.0 / tuple2.f0.size(); + Map<Double, Double> detail = new HashMap<>(0); + for (Double obj : tuple2.f0) { + detail.merge(obj, percent, Double::sum); + } + double max = 0.0; Review comment: Would we rename this variable to be more self-explanatory, e.g. `maxWeight`? ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java ########## @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.knn; + +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.classification.knn.KnnModelData.ModelDataDecoder; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseMatrix; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import 
java.util.Map; +import java.util.PriorityQueue; + +/** A Model which classifies data using the model data computed by {@link Knn}. */ +public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> { + protected Map<Param<?>, Object> params = new HashMap<>(); + private Table modelDataTable; + + public KnnModel() { + ParamUtils.initializeMapWithDefaultValues(params, this); + } + + @Override + public KnnModel setModelData(Table... modelData) { + this.modelDataTable = modelData[0]; + return this; + } + + @Override + public Table[] getModelData() { + return new Table[] {modelDataTable}; + } + + @Override + @SuppressWarnings("unchecked") + public Table[] transform(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<Row> data = tEnv.toDataStream(inputs[0]); + DataStream<KnnModelData> knnModel = KnnModelData.getModelDataStream(modelDataTable); + final String broadcastModelKey = "broadcastModelKey"; + + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll( + inputTypeInfo.getFieldTypes(), BasicTypeInfo.DOUBLE_TYPE_INFO), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol())); + DataStream<Row> output = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(data), + Collections.singletonMap(broadcastModelKey, knnModel), + inputList -> { + DataStream input = inputList.get(0); + return input.map( + new PredictLabelFunction( + broadcastModelKey, getK(), getFeaturesCol()), + outputTypeInfo); + }); + return new Table[] {tEnv.fromDataStream(output)}; + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return this.params; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + ReadWriteUtils.saveModelData( + 
KnnModelData.getModelDataStream(modelDataTable), + path, + new KnnModelData.ModelDataEncoder(), + 0); + } + + /** + * Loads model data from path. + * + * @param env Stream execution environment. + * @param path Model path. + * @return Knn model. + */ + public static KnnModel load(StreamExecutionEnvironment env, String path) throws IOException { + KnnModel model = ReadWriteUtils.loadStageParam(path); + Table modelDataTable = ReadWriteUtils.loadModelData(env, path, new ModelDataDecoder(), 0); + return model.setModelData(modelDataTable); + } + + /** This operator loads model data and predicts result. */ + private static class PredictLabelFunction extends RichMapFunction<Row, Row> { + private final String featureCol; + private KnnModelData knnModelData; + private final int k; + private final String broadcastKey; + + public PredictLabelFunction(String broadcastKey, int k, String featureCol) { + this.k = k; + this.broadcastKey = broadcastKey; + this.featureCol = featureCol; + } + + @Override + public Row map(Row row) { + if (knnModelData == null) { + knnModelData = + (KnnModelData) + getRuntimeContext().getBroadcastVariable(broadcastKey).get(0); + } + DenseVector vector = (DenseVector) row.getField(featureCol); + if (vector == null) { + return Row.join(row, new Row(1)); + } + Tuple2<List<Double>, List<Double>> tuple2 = findNeighbor(vector); + double percent = 1.0 / tuple2.f0.size(); + Map<Double, Double> detail = new HashMap<>(0); Review comment: nits: Would we make this variable name more self-explanatory, e.g. `labelWeights`? ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java ########## @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.knn; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.linalg.DenseMatrix; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * An Estimator which implements the KNN algorithm. + * + * <p>See: https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm. 
+ */ +public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> { + + protected Map<Param<?>, Object> params = new HashMap<>(); + + public Knn() { + ParamUtils.initializeMapWithDefaultValues(params, this); + } + + @Override + public KnnModel fit(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + /* Tuple2 : <sampleVector, label> */ + DataStream<Tuple2<DenseVector, Double>> inputData = + tEnv.toDataStream(inputs[0]) + .map( + new MapFunction<Row, Tuple2<DenseVector, Double>>() { + @Override + public Tuple2<DenseVector, Double> map(Row value) { + Double label = (Double) value.getField(getLabelCol()); + DenseVector feature = + (DenseVector) value.getField(getFeaturesCol()); + return Tuple2.of(feature, label); + } + }); + DataStream<KnnModelData> distributedModelData = prepareModelData(inputData); + DataStream<KnnModelData> modelData = mergeModelData(distributedModelData); + KnnModel model = new KnnModel().setModelData(tEnv.fromDataStream(modelData)); + ReadWriteUtils.updateExistingParams(model, getParamMap()); + return model; + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return this.params; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static Knn load(StreamExecutionEnvironment env, String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + /** + * Prepares distributed knn model data. Constructs the sample matrix and computes norm of + * vectors. + * + * @param inputData Input vector data with label. + * @return Distributed knn model. 
+ */ + private static DataStream<KnnModelData> prepareModelData( + DataStream<Tuple2<DenseVector, Double>> inputData) { + return DataStreamUtils.mapPartition( + inputData, + new RichMapPartitionFunction<Tuple2<DenseVector, Double>, KnnModelData>() { + @Override + public void mapPartition( + Iterable<Tuple2<DenseVector, Double>> values, + Collector<KnnModelData> out) { + List<Tuple2<DenseVector, Double>> dataPoints = new ArrayList<>(0); + for (Tuple2<DenseVector, Double> tuple2 : values) { + dataPoints.add(tuple2); + } + int featureDim = dataPoints.get(0).f0.size(); + DenseMatrix packedFeatures = new DenseMatrix(featureDim, dataPoints.size()); + DenseVector labels = new DenseVector(dataPoints.size()); + for (int i = 0; i < dataPoints.size(); ++i) { + Tuple2<DenseVector, Double> tuple2 = dataPoints.get(i); + labels.values[i] = tuple2.f1; + System.arraycopy( + tuple2.f0.values, + 0, + packedFeatures.values, + i * featureDim, + featureDim); + } + DenseVector featureNorms = computeNorm(packedFeatures); + if (dataPoints.size() > 0) { + out.collect(new KnnModelData(packedFeatures, featureNorms, labels)); + } + } + }, + TypeInformation.of(KnnModelData.class)); Review comment: It seems that we can simplify the code here by calling `mapPartition(...)` without specifying this TypeInformation. Could you double check this? Same for `mergeModelData(...)`. ########## File path: flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java ########## @@ -26,9 +26,91 @@ private static final dev.ludovic.netlib.BLAS JAVA_BLAS = dev.ludovic.netlib.JavaBLAS.getInstance(); - /** y += a * x . */ + /** + * \sum_i |x_i| . + * + * @param x x + * @return \sum_i |x_i| + */ + public static double asum(DenseVector x) { + return JAVA_BLAS.dasum(x.size(), x.values, 0, 1); + } + + /** + * y += a * x . 
+ * + * @param a a + * @param x x + * @param y y + */ public static void axpy(double a, DenseVector x, DenseVector y) { Preconditions.checkArgument(x.size() == y.size(), "Vector size mismatched."); JAVA_BLAS.daxpy(x.size(), a, x.values, 1, y.values, 1); } + + /** + * x \cdot y . + * + * @param x x + * @param y y + * @return x \cdot y + */ + public static double dot(DenseVector x, DenseVector y) { + Preconditions.checkArgument(x.size() == y.size(), "Vector size mismatched."); + return JAVA_BLAS.ddot(x.size(), x.values, 1, y.values, 1); + } + + /** + * \sqrt(\sum_i x_i * x_i) . + * + * @param x x + * @return \sqrt(\sum_i x_i * x_i) + */ + public static double norm2(DenseVector x) { + return JAVA_BLAS.dnrm2(x.size(), x.values, 1); + } + + /** + * x = x * a . + * + * @param a a + * @param x x + */ + public static void scal(double a, DenseVector x) { + JAVA_BLAS.dscal(x.size(), a, x.values, 1); + } + + /** + * y = alpha * A * x + beta * y or y = alpha * (A^T) * x + beta * y. + * + * @param alpha alpha. + * @param A m x n matrix A. + * @param transA Whether transposes matrix y before multiply. + * @param x dense vector with size n. + * @param beta beta. + * @param y dense vector with size m. + */ + public static void gemv( + double alpha, + DenseMatrix A, Review comment: It seems that this line causes the style check failure? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org