weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r755682429



##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,594 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.connector.source.Source;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.connector.file.sink.FileSink;
+import org.apache.flink.connector.file.src.FileSource;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.api.core.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.linalg.DenseMatrix;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.shaded.curator4.com.google.common.base.Preconditions;
+import org.apache.flink.shaded.curator4.com.google.common.collect.ImmutableMap;
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.type.TypeReference;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.utils.LogicalTypeParser;
+import org.apache.flink.table.types.utils.LogicalTypeDataTypeConverter;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.lang3.ArrayUtils;
+import sun.reflect.generics.reflectiveObjects.ParameterizedTypeImpl;
+
+import java.io.IOException;
+import java.lang.reflect.Type;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+import java.util.TreeMap;
+import java.util.function.Function;
+
+/** Knn classification model fitted by estimator. */
+public class KnnModel implements Model<KnnModel>, KnnParams<KnnModel> {
+
+    private static final long serialVersionUID = 1303892137143865652L;
+
+    private static final String BROADCAST_STR = "broadcastModelKey";
+    private static final int FASTDISTANCE_TYPE_INDEX = 0;
+    private static final int DATA_INDEX = 1;
+
+    protected Map<Param<?>, Object> params = new HashMap<>();
+
+    private Table[] modelData;
+
+    /** Creates a model and fills its parameter map via ParamUtils.initializeMapWithDefaultValues. */
+    public KnnModel() {
+        ParamUtils.initializeMapWithDefaultValues(params, this);
+    }
+
+    /**
+     * Creates a model that uses the given parameter map. The map is stored by reference, not
+     * copied.
+     *
+     * @param params parameters for the algorithm.
+     */
+    public KnnModel(Map<Param<?>, Object> params) {
+        this.params = params;
+    }
+
+    /**
+     * Sets the model data used for knn prediction. The array is stored by reference.
+     *
+     * @param modelData tables holding the knn model; {@code transform} reads the first one.
+     * @return this model, to allow call chaining.
+     */
+    @Override
+    public KnnModel setModelData(Table... modelData) {
+        this.modelData = modelData;
+        return this;
+    }
+
+    /**
+     * Returns the model data previously passed to {@code setModelData}.
+     *
+     * @return the stored array of model tables, or {@code null} if none has been set.
+     */
+    @Override
+    public Table[] getModelData() {
+        return modelData;
+    }
+
+    /**
+     * @param inputs a list of tables.
+     * @return result.
+     */
+    @Override
+    public Table[] transform(Table... inputs) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        DataStream<Row> input = tEnv.toDataStream(inputs[0]);
+        DataStream<Row> model = tEnv.toDataStream(modelData[0]);
+
+        Map<String, DataStream<?>> broadcastMap = new HashMap<>(1);
+        broadcastMap.put(BROADCAST_STR, model);
+        ResolvedSchema modelSchema = modelData[0].getResolvedSchema();
+        DataType idType =
+                
modelSchema.getColumnDataTypes().get(modelSchema.getColumnNames().size() - 1);
+
+        ResolvedSchema outputSchema =
+                getOutputSchema(inputs[0].getResolvedSchema(), getParamMap(), 
idType);
+
+        DataType[] dataTypes = outputSchema.getColumnDataTypes().toArray(new 
DataType[0]);
+        TypeInformation<?>[] typeInformations = new 
TypeInformation[dataTypes.length];
+
+        for (int i = 0; i < dataTypes.length; ++i) {
+            typeInformations[i] = 
TypeInformation.of(dataTypes[i].getLogicalType().getClass());
+        }
+
+        Function<List<DataStream<?>>, DataStream<Row>> function =
+                dataStreams -> {
+                    DataStream stream = dataStreams.get(0);
+                    return stream.transform(
+                            "mapFunc",
+                            new RowTypeInfo(
+                                    typeInformations,
+                                    outputSchema.getColumnNames().toArray(new 
String[0])),
+                            new PredictOperator(
+                                    new KnnRichFunction(
+                                            getParamMap(), 
inputs[0].getResolvedSchema())));
+                };
+
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(input), broadcastMap, 
function);
+        return new Table[] {tEnv.fromDataStream(output, 
resolvedSchema2Schema(outputSchema))};
+    }
+
+    /**
+     * Converts a {@link ResolvedSchema} into a {@link Schema} by declaring one column per
+     * name/type pair, each type given as the string form of its logical type.
+     *
+     * @param resolvedSchema resolved schema to convert.
+     * @return schema with the same column names and types.
+     */
+    private static Schema resolvedSchema2Schema(ResolvedSchema resolvedSchema) {
+        Schema.Builder schemaBuilder = Schema.newBuilder();
+        List<String> names = resolvedSchema.getColumnNames();
+        List<DataType> types = resolvedSchema.getColumnDataTypes();
+        int position = 0;
+        for (String name : names) {
+            String typeString = types.get(position).getLogicalType().toString();
+            schemaBuilder.column(name, typeString);
+            position++;
+        }
+        return schemaBuilder.build();
+    }
+
+    private static class KnnRichFunction extends RichMapFunction<Row, Row> {
+        private boolean firstEle = true;
+        private final String[] reservedCols;
+        private String[] selectedCols;
+        private String vectorCol;
+        private DataType idType;
+        private transient KnnModelData modelData;
+        private final Integer topN;
+        private Map<String, Object> meta;
+
+        /**
+         * Creates the prediction function.
+         *
+         * @param params parameter map; the value stored under {@code KnnParams.K} becomes topN.
+         * @param dataSchema schema of the input table; all its column names are kept as reserved
+         *     columns.
+         */
+        public KnnRichFunction(Map<Param<?>, Object> params, ResolvedSchema dataSchema) {
+            List<String> inputColumns = dataSchema.getColumnNames();
+            this.reservedCols = inputColumns.toArray(new String[0]);
+            this.topN = (Integer) params.get(KnnParams.K);
+        }
+
+        /**
+         * Predicts the label for one input row.
+         *
+         * <p>On the first call the model is loaded from the broadcast variable. The feature vector
+         * is built from the selected feature columns when they are set, otherwise it is parsed
+         * from the vector column.
+         *
+         * @param row input row.
+         * @return a row holding every reserved input column followed by the predicted label.
+         * @throws Exception if loading the model or computing the prediction fails.
+         */
+        @Override
+        public Row map(Row row) throws Exception {
+            if (firstEle) {
+                loadModel(getRuntimeContext().getBroadcastVariable(BROADCAST_STR));
+                firstEle = false;
+            }
+            DenseVector vector;
+            if (null != selectedCols) {
+                vector = new DenseVector(new double[selectedCols.length]);
+                for (int i = 0; i < selectedCols.length; i++) {
+                    Object feature = row.getField(selectedCols[i]);
+                    Preconditions.checkNotNull(feature, "There is NULL in featureCols!");
+                    vector.set(i, ((Number) feature).doubleValue());
+                }
+            } else {
+                vector = DenseVector.fromString(row.getField(vectorCol).toString());
+            }
+            // The serialized neighbors must not be lower-cased: that would also rewrite the label
+            // values themselves (e.g. "A" -> "a"). extractObject already normalizes the JSON keys.
+            String neighbors = findNeighbor(vector, topN, modelData);
+
+            Row ret = new Row(reservedCols.length + 1);
+            for (int i = 0; i < reservedCols.length; ++i) {
+                ret.setField(i, row.getField(reservedCols[i]));
+            }
+
+            Tuple2<Object, String> tuple2 = getResultFormat(extractObject(neighbors, idType));
+            ret.setField(reservedCols.length, tuple2.f0);
+            return ret;
+        }
+
+        /**
+         * Finds the nearest topN neighbors of the input among all model entries and serializes
+         * them.
+         *
+         * @param input input node.
+         * @param topN top N.
+         * @param modelData model data to search in.
+         * @return json string produced by {@code serializeResult} for the neighbor ids and their
+         *     metrics, both in reversed poll order.
+         */
+        private String findNeighbor(Object input, Integer topN, KnnModelData modelData) {
+            PriorityQueue<Tuple2<Double, Object>> queue =
+                    new PriorityQueue<>(modelData.getQueueComparator());
+            search(input, topN, queue, modelData);
+
+            List<Object> neighborIds = new ArrayList<>();
+            List<Double> distances = new ArrayList<>();
+            // poll() returns null only once the queue is drained.
+            for (Tuple2<Double, Object> next = queue.poll(); next != null; next = queue.poll()) {
+                neighborIds.add(castTo(next.f1, idType));
+                distances.add(next.f0);
+            }
+            Collections.reverse(neighborIds);
+            Collections.reverse(distances);
+            queue.clear();
+            return serializeResult(neighborIds, ImmutableMap.of("METRIC", distances));
+        }
+
+        /**
+         * Serializes the neighbors to json. The result is a json object whose values are
+         * themselves json strings; the key "ID" always sorts first, the other keys follow in
+         * natural string order.
+         *
+         * @param objectValue the nearest nodes found.
+         * @param others additional named lists (e.g. the metric of the nodes); may be null.
+         * @return serialized result.
+         */
+        private String serializeResult(List<Object> objectValue, Map<String, List<Double>> others) {
+            final String id = "ID";
+            Map<String, String> result =
+                    new TreeMap<>(
+                            (left, right) -> {
+                                boolean leftIsId = id.equals(left);
+                                boolean rightIsId = id.equals(right);
+                                if (leftIsId || rightIsId) {
+                                    if (leftIsId && rightIsId) {
+                                        return 0;
+                                    }
+                                    return leftIsId ? -1 : 1;
+                                }
+                                return left.compareTo(right);
+                            });
+
+            result.put(id, ReadWriteUtils.OBJECT_MAPPER.toJson(objectValue));
+
+            if (others == null) {
+                return ReadWriteUtils.OBJECT_MAPPER.toJson(result);
+            }
+            for (Map.Entry<String, List<Double>> entry : others.entrySet()) {
+                String serialized = ReadWriteUtils.OBJECT_MAPPER.toJson(entry.getValue());
+                result.put(entry.getKey(), serialized);
+            }
+            return ReadWriteUtils.OBJECT_MAPPER.toJson(result);
+        }
+
+        /**
+         * Computes the distance from the input to every model block and collects the candidates
+         * in the queue. When topN is null every candidate is added; otherwise the queue is
+         * maintained through {@code updateQueue}.
+         *
+         * @param input input node.
+         * @param topN top N, or null to keep all candidates.
+         * @param priorityQueue priority queue receiving the candidates.
+         * @param modelData model data to search in.
+         */
+        private void search(
+                Object input,
+                Integer topN,
+                PriorityQueue<Tuple2<Double, Object>> priorityQueue,
+                KnnModelData modelData) {
+            Object sample = prepareSample(input, modelData);
+            Tuple2<Double, Object> head = null;
+            for (int block = 0; block < modelData.getLength(); block++) {
+                ArrayList<Tuple2<Double, Object>> candidates = computeDistance(sample, block);
+                if (candidates != null && candidates.size() != 0) {
+                    for (Tuple2<Double, Object> candidate : candidates) {
+                        if (topN != null) {
+                            head = updateQueue(priorityQueue, topN, candidate, head);
+                        } else {
+                            priorityQueue.add(Tuple2.of(candidate.f0, candidate.f1));
+                        }
+                    }
+                }
+            }
+        }
+
+        /**
+         * Offers a new value to a queue bounded to topN entries.
+         *
+         * <p>While the queue holds fewer than topN entries the value is always added. Once it is
+         * full, the current head is replaced only if the queue's comparator orders the head before
+         * the new value.
+         *
+         * @param map queue.
+         * @param topN top N.
+         * @param newValue new value; ignored when null.
+         * @param head head value.
+         * @param <T> id type.
+         * @return head value after the update.
+         */
+        private <T> Tuple2<Double, T> updateQueue(
+                PriorityQueue<Tuple2<Double, T>> map,
+                int topN,
+                Tuple2<Double, T> newValue,
+                Tuple2<Double, T> head) {
+            if (newValue == null) {
+                return head;
+            }
+            if (map.size() < topN) {
+                map.add(Tuple2.of(newValue.f0, newValue.f1));
+                return map.peek();
+            }
+            if (map.comparator().compare(head, newValue) >= 0) {
+                return head;
+            }
+            // Reuse the polled tuple object to hold the new value.
+            Tuple2<Double, T> recycled = map.poll();
+            recycled.f0 = newValue.f0;
+            recycled.f1 = newValue.f1;
+            map.add(recycled);
+            return map.peek();
+        }
+
+        /**
+         * Prepares the sample for distance computation: the input's string form is parsed into a
+         * DenseVector and passed, paired with a null second element, to the model's fast distance
+         * via prepareVectorData.
+         *
+         * @param input sample to parse.
+         * @param modelData model data providing the fast distance.
+         * @return whatever prepareVectorData returns (computeDistance later casts it to
+         *     FastDistanceVectorData).
+         */
+        private Object prepareSample(Object input, KnnModelData modelData) {
+            return modelData
+                    .getFastDistance()
+                    
.prepareVectorData(Tuple2.of(DenseVector.fromString(input.toString()), null));
+        }
+
+        /**
+         * Computes the distance between the prepared sample and every row of one model block.
+         *
+         * @param input prepared sample; must be a FastDistanceVectorData.
+         * @param index index of the model block in the dictionary data.
+         * @return one (distance, first field of the row) tuple per row of the block.
+         */
+        private ArrayList<Tuple2<Double, Object>> computeDistance(Object input, Integer index) {
+            FastDistanceMatrixData data = modelData.getDictData().get(index);
+            DenseMatrix res =
+                    modelData.getFastDistance().calc((FastDistanceVectorData) input, data);
+            Row[] curRows = data.getRows();
+            // Size the list up front; exactly one tuple is produced per row.
+            ArrayList<Tuple2<Double, Object>> list = new ArrayList<>(curRows.length);
+            for (int i = 0; i < curRows.length; i++) {
+                Tuple2<Double, Object> tuple = Tuple2.of(res.getData()[i], curRows[i].getField(0));
+                list.add(tuple);
+            }
+            return list;
+        }
+
+        /**
+         * Derives the prediction from the neighbor labels by majority vote.
+         *
+         * <p>Every neighbor contributes an equal share; the label with the largest accumulated
+         * share wins. A null or empty neighbor list yields a null prediction and an empty detail
+         * map.
+         *
+         * @param tuple initial result from knn predictor; f0 holds the neighbor labels.
+         * @return the predicted label and the json-serialized share of every label.
+         */
+        private Tuple2<Object, String> getResultFormat(Tuple2<List<Object>, List<Object>> tuple) {
+            // extractObject may return a null id list; treat it like "no neighbors".
+            List<Object> labels = tuple.f0 == null ? Collections.<Object>emptyList() : tuple.f0;
+            Map<Object, Double> detail = new HashMap<>();
+
+            if (!labels.isEmpty()) {
+                double percent = 1.0 / labels.size();
+                for (Object obj : labels) {
+                    detail.merge(obj, percent, Double::sum);
+                }
+            }
+
+            double max = 0.0;
+            Object prediction = null;
+
+            for (Map.Entry<Object, Double> entry : detail.entrySet()) {
+                if (entry.getValue() > max) {
+                    max = entry.getValue();
+                    prediction = entry.getKey();
+                }
+            }
+
+            return Tuple2.of(prediction, ReadWriteUtils.OBJECT_MAPPER.toJson(detail));
+        }
+
+        /**
+         * Parses the json produced by {@code serializeResult} back into the id list and the
+         * metric list. Keys are matched case-insensitively after trimming.
+         *
+         * @param json json format result of knn prediction.
+         * @param idType id type; its default conversion class is the element type of the id list.
+         * @return the id list and the metric list; an element is null when its key is absent.
+         * @throws IllegalStateException if the json cannot be deserialized.
+         */
+        private Tuple2<List<Object>, List<Object>> extractObject(String json, DataType idType) {
+            Map<String, String> deserializedJson;
+            try {
+                deserializedJson =
+                        ReadWriteUtils.OBJECT_MAPPER.fromJson(
+                                json, new TypeReference<Map<String, String>>() {}.getType());
+            } catch (Exception e) {
+                // Keep the original exception as the cause so the parse error is not lost.
+                throw new IllegalStateException(
+                        "Fail to deserialize json '" + json + "', please check the input!", e);
+            }
+
+            Map<String, String> lowerCaseDeserializedJson = new HashMap<>();
+            for (Map.Entry<String, String> entry : deserializedJson.entrySet()) {
+                lowerCaseDeserializedJson.put(
+                        entry.getKey().trim().toLowerCase(), entry.getValue());
+            }
+
+            Type type = idType.getLogicalType().getDefaultConversion();
+            String ids = lowerCaseDeserializedJson.get("id");
+            String metric = lowerCaseDeserializedJson.get("metric");
+
+            List<Object> idList = null;
+            if (ids != null) {
+                idList =
+                        ReadWriteUtils.OBJECT_MAPPER.fromJson(
+                                ids,
+                                ParameterizedTypeImpl.make(List.class, new Type[] {type}, null));
+            }
+
+            // Guard on the metric itself: the original checked ids here, so a missing metric was
+            // passed on as null whenever ids were present.
+            List<Object> metricList = null;
+            if (metric != null) {
+                metricList =
+                        ReadWriteUtils.OBJECT_MAPPER.fromJson(
+                                metric,
+                                ParameterizedTypeImpl.make(
+                                        List.class, new Type[] {Double.class}, null));
+            }
+            return Tuple2.of(idList, metricList);
+        }
+
+        /**
+         * Builds the in-memory model from the broadcast rows.
+         *
+         * <p>The meta json is read from the second-to-last field of any row that carries one; rows
+         * whose type field equals 1 contribute one distance matrix block each. The meta then
+         * decides whether features come from the feature columns or from the vector column, and
+         * provides the id type.
+         *
+         * @param broadcastVar rows of the broadcast model table.
+         * @throws IllegalStateException if no row carries the model meta.
+         */
+        private void loadModel(List<Object> broadcastVar) {
+            List<FastDistanceMatrixData> dictData = new ArrayList<>();
+            for (Object obj : broadcastVar) {
+                Row row = (Row) obj;
+                Object metaJson = row.getField(row.getArity() - 2);
+                if (metaJson != null) {
+                    meta = ReadWriteUtils.OBJECT_MAPPER.fromJson((String) metaJson, HashMap.class);
+                }
+            }
+            // Without meta neither the feature columns nor the id type can be resolved.
+            Preconditions.checkState(meta != null, "The model data does not contain any meta row.");
+
+            for (Object obj : broadcastVar) {
+                Row row = (Row) obj;
+                if (row.getField(FASTDISTANCE_TYPE_INDEX) != null) {
+                    long type = (long) row.getField(FASTDISTANCE_TYPE_INDEX);
+                    if (type == 1L) {
+                        dictData.add(
+                                FastDistanceMatrixData.fromString(
+                                        (String) row.getField(DATA_INDEX)));
+                    }
+                }
+            }
+            if (meta.containsKey(KnnParams.FEATURE_COLS.name)) {
+                selectedCols =
+                        ReadWriteUtils.OBJECT_MAPPER.fromJson(
+                                (String) meta.get(KnnParams.FEATURE_COLS.name), String[].class);
+            } else {
+                vectorCol =
+                        ReadWriteUtils.OBJECT_MAPPER.fromJson(
+                                (String) meta.get(KnnParams.VECTOR_COL.name), String.class);
+            }
+
+            modelData = new KnnModelData(dictData, new EuclideanDistance());
+            idType =
+                    LogicalTypeDataTypeConverter.toDataType(
+                            LogicalTypeParser.parse((String) this.meta.get("idType")));
+        }
+    }
+
+    /**
+     * Stream operator that applies a rich map function to every incoming row and emits the mapped
+     * row. The mapper is responsible for loading the model data and doing the prediction; to write
+     * another prediction operator, implement a dedicated mapper for it.
+     */
+    private static class PredictOperator
+            extends AbstractUdfStreamOperator<Row, RichMapFunction<Row, Row>>
+            implements OneInputStreamOperator<Row, Row> {
+
+        /**
+         * Creates the operator.
+         *
+         * @param userFunction mapper applied to every element.
+         */
+        public PredictOperator(RichMapFunction<Row, Row> userFunction) {
+            super(userFunction);
+        }
+
+        /** Maps the record's value with the user function and collects the result. */
+        @Override
+        public void processElement(StreamRecord<Row> streamRecord) throws Exception {
+            Row mapped = userFunction.map(streamRecord.getValue());
+            output.collect(new StreamRecord<>(mapped));
+        }
+    }
+
+    /**
+     * Builds the output schema: all columns of the data schema followed by the prediction column.
+     *
+     * @param dataSchema schema of the input table.
+     * @param params parameter map; the prediction column name is read from
+     *     {@code KnnParams.PREDICTION_COL}.
+     * @param idType type of the prediction column.
+     * @return schema of the output table.
+     */
+    private ResolvedSchema getOutputSchema(
+            ResolvedSchema dataSchema, Map<Param<?>, Object> params, DataType idType) {
+        String[] inputNames = dataSchema.getColumnNames().toArray(new String[0]);
+        DataType[] inputTypes = dataSchema.getColumnDataTypes().toArray(new DataType[0]);
+
+        String predictionCol = (String) params.get(KnnParams.PREDICTION_COL);
+        String[] outputNames = ArrayUtils.addAll(inputNames, new String[] {predictionCol});
+        DataType[] outputTypes = ArrayUtils.addAll(inputTypes, new DataType[] {idType});
+        return ResolvedSchema.physical(outputNames, outputTypes);
+    }
+
+    /**
+     * Casts data x to the Java value matching type t.
+     *
+     * <p>Numbers are narrowed or widened through {@link Number}; any other value is parsed from
+     * its string form. Both {@code TINYINT} and, as before, {@code BYTES} produce a {@link Byte}.
+     *
+     * @param x data; may be null.
+     * @param t target type.
+     * @return the converted value, or null if x is null.
+     * @throws RuntimeException if t is not one of the supported types.
+     */
+    private static Object castTo(Object x, DataType t) {
+        if (x == null) {
+            return null;
+        } else if (t.equals(DataTypes.BOOLEAN())) {
+            if (x instanceof Boolean) {
+                return x;
+            }
+            return Boolean.valueOf(x.toString());
+        } else if (t.equals(DataTypes.TINYINT()) || t.equals(DataTypes.BYTES())) {
+            // TINYINT is the type whose values are single bytes; BYTES is kept for compatibility.
+            if (x instanceof Number) {
+                return ((Number) x).byteValue();
+            }
+            return Byte.valueOf(x.toString());
+        } else if (t.equals(DataTypes.INT())) {
+            if (x instanceof Number) {
+                return ((Number) x).intValue();
+            }
+            return Integer.valueOf(x.toString());
+        } else if (t.equals(DataTypes.BIGINT())) {
+            if (x instanceof Number) {
+                return ((Number) x).longValue();
+            }
+            return Long.valueOf(x.toString());
+        } else if (t.equals(DataTypes.FLOAT())) {
+            if (x instanceof Number) {
+                return ((Number) x).floatValue();
+            }
+            return Float.valueOf(x.toString());
+        } else if (t.equals(DataTypes.DOUBLE())) {
+            if (x instanceof Number) {
+                return ((Number) x).doubleValue();
+            }
+            return Double.valueOf(x.toString());
+        } else if (t.equals(DataTypes.STRING())) {
+            if (x instanceof String) {
+                return x;
+            }
+            return x.toString();
+        } else {
+            // Report the data type itself; the wrapper's class name does not identify it.
+            throw new RuntimeException("unsupported type: " + t);
+        }
+    }
+
+    /** @return parameters for algorithm. */
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        if (null == this.params) {

Review comment:
       done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to