weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r764497065
##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
##########
@@ -0,0 +1,217 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.MapPartitionFunctionWrapper;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the KNN algorithm. KNN is to classify unlabeled observations by
+ * assigning them to the class of the most similar labeled examples.
+ *
+ * <p>See: https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm.
+ */
+public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> {
+
+    protected Map<Param<?>, Object> params = new HashMap<>();
+
+    /** Constructor. */
+    public Knn() {
+        ParamUtils.initializeMapWithDefaultValues(params, this);
+    }
+
+    /**
+     * Fits data and produces knn model.
+     *
+     * @param inputs A list of tables, including train data table.
+     * @return Knn model.
+     */
+    @Override
+    public KnnModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Row> input = tEnv.toDataStream(inputs[0]);
+        String labelCol = getLabelCol();
+
+        DataStream<Tuple2<DenseVector, Integer>> trainData =
+                input.map(
+                        new MapFunction<Row, Tuple2<DenseVector, Integer>>() {
+                            @Override
+                            public Tuple2<DenseVector, Integer> map(Row value) {
+                                Integer label = (Integer) value.getField(labelCol);
+                                DenseVector vec = (DenseVector) value.getField(getFeaturesCol());
+                                return Tuple2.of(vec, label);
+                            }
+                        });
+
+        DataStream<Row> model = buildModel(trainData);
+        KnnModel knnModel =
+                new KnnModel()
+                        .setFeaturesCol(getFeaturesCol())
+                        .setK(getK())
+                        .setPredictionCol(getPredictionCol());
+        knnModel.setModelData(tEnv.fromDataStream(model, KnnModelData.getModelSchema()));
+        return knnModel;
+    }
+
+    /**
+     * Builds knn model.
+     *
+     * @param dataStream Input data.
+     * @return Knn model.
+     */
+    private static DataStream<Row> buildModel(DataStream<Tuple2<DenseVector, Integer>> dataStream) {
+        Schema schema = KnnModelData.getModelSchema();
+        return dataStream.transform(
+                "build knn model",
+                TableUtils.getRowTypeInfo(schema),
+                new MapPartitionFunctionWrapper<>(
+                        new RichMapPartitionFunction<Tuple2<DenseVector, Integer>, Row>() {
+                            @Override
+                            public void mapPartition(
+                                    Iterable<Tuple2<DenseVector, Integer>> values,
+                                    Collector<Row> out) {
+                                List<Tuple3<DenseMatrix, DenseVector, Integer[]>> list =
+                                        prepareMatrixData(values);
+                                for (Tuple3<DenseMatrix, DenseVector, Integer[]> t3 : list) {
+                                    Row ret = new Row(3);
+                                    ret.setField(0, t3.f0);
+                                    ret.setField(1, t3.f1);
+                                    ret.setField(2, t3.f2);
+                                    out.collect(ret);
+                                }
+                            }
+                        }));
+    }
+
+    /**
+     * Prepares matrix data, the output is a list of Tuple3, which includes vectors, vecNorms and
+     * label.
+     *
+     * @param trainData Input train data.
+     * @return Model data in format of list tuple3.
+     */
+    private static List<Tuple3<DenseMatrix, DenseVector, Integer[]>> prepareMatrixData(
+            Iterable<Tuple2<DenseVector, Integer>> trainData) {
+        final int size = 5 * 1024 * 1024;

Review comment:
   The L2 cache size in typical computer architectures was taken into account when choosing this value; this helps decrease the cache miss rate.
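To make the remark above concrete, here is a minimal sketch of how a fixed byte budget (such as 5 * 1024 * 1024) can be turned into a per-block row count when packing dense vectors into matrix blocks. This is an illustration only, not the PR's implementation: the class name BlockSizingSketch, the helper names, and the assumption of 8-byte double entries are all assumptions made for the example.

import java.util.ArrayList;
import java.util.List;

/** Illustrative sketch: split vectors into blocks whose raw payload stays under a byte budget. */
public class BlockSizingSketch {

    /** Number of vectors that fit into budgetBytes, assuming 8-byte double entries. */
    static int rowsPerBlock(int vectorDim, int budgetBytes) {
        int bytesPerVector = vectorDim * Double.BYTES;
        return Math.max(1, budgetBytes / bytesPerVector);
    }

    /** Splits the input into consecutive blocks, each holding at most rowsPerBlock rows. */
    static List<double[][]> toBlocks(List<double[]> vectors, int budgetBytes) {
        List<double[][]> blocks = new ArrayList<>();
        if (vectors.isEmpty()) {
            return blocks;
        }
        int rows = rowsPerBlock(vectors.get(0).length, budgetBytes);
        for (int start = 0; start < vectors.size(); start += rows) {
            int end = Math.min(start + rows, vectors.size());
            blocks.add(vectors.subList(start, end).toArray(new double[0][]));
        }
        return blocks;
    }

    public static void main(String[] args) {
        // With a 5 MiB budget and 100-dimensional vectors, each block holds roughly 6553 rows.
        System.out.println(rowsPerBlock(100, 5 * 1024 * 1024));
    }
}

Bounding each block this way keeps the matrix processed in one pass small enough to stay resident in a fast cache level, which is the cache-miss argument the review comment makes.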

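The model rows also carry precomputed vector norms (vecNorms). A common reason for storing them, sketched below under the assumption that squared Euclidean distance is used, is the identity ||q - x||^2 = ||q||^2 + ||x||^2 - 2 q·x, which lets the per-row term ||x||^2 be computed once at model-build time instead of once per query. The class and method names here are hypothetical and not taken from the PR.

/** Illustrative sketch of the squared-norm expansion; not the PR's code. */
public class NormTrickSketch {

    /** Dot product of two equal-length vectors. */
    static double dot(double[] a, double[] b) {
        double s = 0.0;
        for (int i = 0; i < a.length; i++) {
            s += a[i] * b[i];
        }
        return s;
    }

    /** Squared Euclidean distance, reusing precomputed squared norms of both vectors. */
    static double squaredDistance(double qNormSq, double xNormSq, double[] q, double[] x) {
        return qNormSq + xNormSq - 2.0 * dot(q, x);
    }

    public static void main(String[] args) {
        double[] q = {1.0, 2.0};
        double[] x = {4.0, 6.0};
        double qNormSq = dot(q, q); // computed once per query vector
        double xNormSq = dot(x, x); // precomputed once when the model is built
        System.out.println(squaredDistance(qNormSq, xNormSq, q, x)); // prints 25.0
    }
}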