weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r765458856



##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,378 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.functions.AbstractRichFunction;
+import org.apache.flink.api.connector.source.Source;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.connector.file.sink.FileSink;
+import org.apache.flink.connector.file.src.FileSource;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.types.Row;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.PriorityQueue;
+
+/** Knn model fitted by estimator. */
+public class KnnModel implements Model<KnnModel>, KnnModelParams<KnnModel> {
+    /** Tables holding the model data; assigned by setModelData or load. */
+    private Table[] modelData;
+
+    /** Parameter values of this stage, keyed by parameter definition. */
+    protected Map<Param<?>, Object> params = new HashMap<>();
+
+    /** Creates a KnnModel whose parameters are initialized with their default values. */
+    public KnnModel() {
+        ParamUtils.initializeMapWithDefaultValues(this.params, this);
+    }
+
+    /**
+     * Attaches the tables holding the fitted KNN model data.
+     *
+     * @param modelData Model data tables; transform and save only read the first one.
+     * @return This model, so calls can be chained.
+     */
+    @Override
+    public KnnModel setModelData(Table... modelData) {
+        this.modelData = modelData;
+        return this;
+    }
+
+    /**
+     * Returns the model data tables attached to this model.
+     *
+     * @return The model data tables, or null if none have been set or loaded yet.
+     */
+    @Override
+    public Table[] getModelData() {
+        return this.modelData;
+    }
+
+    /**
+     * Predicts a label for every row of the first input table using this KNN model.
+     *
+     * <p>The first model data table is broadcast to all prediction tasks. Each output row holds
+     * all fields of the input row followed by the prediction column (type INT), whose name is
+     * given by {@code getPredictionCol()}.
+     *
+     * @param inputs Input tables; only the first one is used, and it must contain the features
+     *     column.
+     * @return A single-element array holding the prediction table.
+     */
+    @Override
+    public Table[] transform(Table... inputs) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Row> input = tEnv.toDataStream(inputs[0]);
+        DataStream<Row> model = tEnv.toDataStream(modelData[0]);
+        final String broadcastKey = "broadcastModelKey";
+        String predictionCol = getPredictionCol();
+        DataType predictionType = DataTypes.INT();
+        ResolvedSchema outputSchema =
+                TableUtils.getOutputSchema(
+                        inputs[0].getResolvedSchema(), predictionCol, predictionType);
+
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(input),
+                        Collections.singletonMap(broadcastKey, model),
+                        inputList -> {
+                            // inputList corresponds to the single DataStream<Row> passed above.
+                            // NOTE(review): assumes BroadcastUtils keeps the element type intact.
+                            @SuppressWarnings("unchecked")
+                            DataStream<Row> inputData = (DataStream<Row>) inputList.get(0);
+                            return inputData.transform(
+                                    "mapFunc",
+                                    TableUtils.getRowTypeInfo(outputSchema),
+                                    new PredictOperator(broadcastKey, getK(), getFeaturesCol()));
+                        });
+
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    /**
+     * Returns the map that backs this model's parameter values; it is not copied.
+     *
+     * @return Parameter map of this model.
+     */
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return params;
+    }
+
+    /**
+     * Saves the model data and the model metadata under the given path.
+     *
+     * <p>The first model data table is converted to a DataStream and attached to a row-format
+     * FileSink writing to the data sub-path. NOTE(review): the rows are presumably only written
+     * once the enclosing job is executed.
+     *
+     * @param path Path to save the model to.
+     */
+    @Override
+    public void save(String path) throws IOException {
+        Table modelTable = modelData[0];
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) modelTable).getTableEnvironment();
+
+        FileSink<Row> modelDataSink =
+                FileSink.forRowFormat(
+                                new Path(ReadWriteUtils.getDataPath(path)),
+                                new KnnModelData.ModelDataEncoder())
+                        .withRollingPolicy(OnCheckpointRollingPolicy.build())
+                        .withBucketAssigner(new BasePathBucketAssigner<>())
+                        .build();
+        DataStream<Row> modelStream = tEnv.toDataStream(modelTable);
+        modelStream.sinkTo(modelDataSink);
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    /**
+     * Loads a KnnModel, parameters and model data, from the given path.
+     *
+     * @param env Stream execution environment used to read the model data files.
+     * @param path Path the model was saved to.
+     * @return The loaded KnnModel, with its model data table attached.
+     */
+    public static KnnModel load(StreamExecutionEnvironment env, String path) throws IOException {
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+        KnnModel model = ReadWriteUtils.loadStageParam(path);
+
+        Source<Row, ?, ?> fileSource =
+                FileSource.forRecordStreamFormat(
+                                new KnnModelData.ModelDataStreamFormat(),
+                                ReadWriteUtils.getDataPaths(path))
+                        .build();
+        DataStream<Row> rows = env.fromSource(fileSource, WatermarkStrategy.noWatermarks(), "data");
+        Table modelDataTable = tEnv.fromDataStream(rows, KnnModelData.getModelSchema());
+        model.modelData = new Table[] {modelDataTable};
+        return model;
+    }
+
+    /** This operator loads model data and predicts result. */
+    private static class PredictOperator
+            extends AbstractUdfStreamOperator<Row, AbstractRichFunction>
+            implements OneInputStreamOperator<Row, Row> {
+
+        /** Name of the broadcast variable that carries the model data. */
+        private final String broadcastKey;
+        /** Name of the input column holding the feature vector. */
+        private final String featureCol;
+        /** Number of nearest neighbors searched for each input row. */
+        private final Integer topN;
+        /** Model data used for the search; presumably populated by loadModel (not visible here). */
+        private transient KnnModelData modelData;
+        /** True until the first element has been processed and the model has been loaded. */
+        private boolean firstEle = true;
+
+        /**
+         * Creates the prediction operator.
+         *
+         * @param broadcastKey Name of the broadcast variable holding the model data.
+         * @param k Number of nearest neighbors used for each prediction.
+         * @param featureCol Name of the input column holding the feature vector.
+         */
+        public PredictOperator(String broadcastKey, int k, String featureCol) {
+            super(new AbstractRichFunction() {});
+            this.featureCol = featureCol;
+            this.broadcastKey = broadcastKey;
+            this.topN = k;
+        }
+
+        /** Emits exactly one prediction row for every incoming row. */
+        @Override
+        public void processElement(StreamRecord<Row> streamRecord) {
+            Row prediction = map(streamRecord.getValue());
+            output.collect(new StreamRecord<>(prediction));
+        }
+
+        /**
+         * Computes the prediction for one input row.
+         *
+         * <p>On the first call the model is loaded from the broadcast variable.
+         *
+         * @param row Input row; its {@code featureCol} field must hold a DenseVector.
+         * @return A new row with every input field followed by the prediction.
+         */
+        private Row map(Row row) {
+            if (firstEle) {
+                loadModel(userFunction.getRuntimeContext().getBroadcastVariable(broadcastKey));
+                firstEle = false;
+            }
+            DenseVector features = (DenseVector) row.getField(featureCol);
+            Tuple2<List<Integer>, List<Double>> neighbors =
+                    findNeighbor(features, topN, modelData);
+
+            int inputArity = row.getArity();
+            Row result = new Row(inputArity + 1);
+            int pos = 0;
+            while (pos < inputArity) {
+                result.setField(pos, row.getField(pos));
+                pos++;
+            }
+            result.setField(inputArity, getResult(neighbors));
+            return result;
+        }
+
+        /**
+         * Finds the {@code topN} nearest neighbors of {@code input} among all dictionary vectors.
+         *
+         * @param input Input vector.
+         * @param topN Maximum number of neighbors to return.
+         * @param modelData Model data whose dictionary vectors are searched.
+         * @return Neighbor indices and their distances, in reverse of the queue's poll order. Since
+         *     updateQueue always evicts the queue head, this puts the best candidate first.
+         */
+        private Tuple2<List<Integer>, List<Double>> findNeighbor(
+                DenseVector input, Integer topN, KnnModelData modelData) {
+            PriorityQueue<Tuple2<Double, Integer>> priorityQueue =
+                    new PriorityQueue<>(modelData.getQueueComparator());
+            search(input, topN, priorityQueue, modelData);
+
+            int size = priorityQueue.size();
+            List<Integer> items = new ArrayList<>(size);
+            List<Double> metrics = new ArrayList<>(size);
+            // poll() drains the queue completely, so no explicit clear() is needed afterwards.
+            while (!priorityQueue.isEmpty()) {
+                Tuple2<Double, Integer> result = priorityQueue.poll();
+                items.add(result.f1);
+                metrics.add(result.f0);
+            }
+            // The head is the eviction candidate (the worst kept entry), so reverse the poll order.
+            Collections.reverse(items);
+            Collections.reverse(metrics);
+            return Tuple2.of(items, metrics);
+        }
+
+        /**
+         * Offers the distance from {@code input} to every dictionary vector to the bounded queue.
+         *
+         * <p>NOTE(review): computeDistance reads the operator's {@code modelData} field rather than
+         * the {@code modelData} parameter. They are the same object when called via map.
+         *
+         * @param input Input vector.
+         * @param topN Maximum number of entries kept in the queue.
+         * @param priorityQueue Queue receiving (distance, index) candidates.
+         * @param modelData Model data whose dictionary blocks are scanned.
+         */
+        private void search(
+                DenseVector input,
+                Integer topN,
+                PriorityQueue<Tuple2<Double, Integer>> priorityQueue,
+                KnnModelData modelData) {
+            Tuple2<DenseVector, Double> inputWithNorm = computeNorm(input);
+            Tuple2<Double, Integer> currentHead = null;
+            for (int block = 0; block < modelData.getLength(); block++) {
+                for (Tuple2<Double, Integer> candidate : computeDistance(inputWithNorm, block)) {
+                    currentHead = updateQueue(priorityQueue, topN, candidate, currentHead);
+                }
+            }
+        }
+
+        /**
+         * Offers a candidate to a priority queue that keeps at most {@code topN} entries.
+         *
+         * <p>While the queue has room, a copy of the candidate is added. Once the queue is full, the
+         * candidate replaces the current head only if the queue's comparator orders the head before
+         * it. The evicted head tuple is reused to hold the candidate's values.
+         *
+         * @param pq Bounded queue; it must have been created with a comparator.
+         * @param topN Maximum number of entries the queue may hold.
+         * @param newValue Candidate (distance, payload) pair; its values are copied, not stored.
+         * @param head Queue head returned by the previous call, or null before the first call.
+         * @return The queue head after this update.
+         */
+        private <T> Tuple2<Double, T> updateQueue(
+                PriorityQueue<Tuple2<Double, T>> pq,
+                int topN,
+                Tuple2<Double, T> newValue,
+                Tuple2<Double, T> head) {
+            if (pq.size() >= topN) {
+                if (pq.comparator().compare(head, newValue) >= 0) {
+                    // The candidate does not beat the current head: leave the queue untouched.
+                    return head;
+                }
+                Tuple2<Double, T> evicted = pq.poll();
+                assert evicted != null;
+                evicted.f0 = newValue.f0;
+                evicted.f1 = newValue.f1;
+                pq.add(evicted);
+            } else {
+                pq.add(Tuple2.of(newValue.f0, newValue.f1));
+            }
+            return pq.peek();
+        }
+
+        /**
+         * Computes distance between sample and dictionary vectors.
+         *
+         * @param input Samples with l2 norm.
+         * @param index Dictionary vectors index.
+         * @return Distances.
+         */
+        private List<Tuple2<Double, Integer>> computeDistance(
+                Tuple2<DenseVector, Double> input, Integer index) {
+            Tuple3<DenseMatrix, DenseVector, int[]> data = 
modelData.getDictData().get(index);
+
+            DenseMatrix vectors = data.f0;
+            DenseMatrix distanceMatrix = new 
DenseMatrix(Objects.requireNonNull(vectors).numCols, 1);
+
+            DenseVector norm = data.f1;
+            double[] normL2Square = Objects.requireNonNull(norm).values;
+            BLAS.gemv(-2.0, vectors, true, input.f0, 0.0, new 
DenseVector(distanceMatrix.values));
+            for (int i = 0; i < distanceMatrix.values.length; i++) {
+                distanceMatrix.values[i] = 
Math.sqrt(Math.abs(distanceMatrix.values[i] + input.f1 + normL2Square[i]));
+            }
+
+            List<Tuple2<Double, Integer>> list = new ArrayList<>(0);

Review comment:
       OK




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to