weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r752889883



##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,343 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.connector.source.Source;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.connector.file.sink.FileSink;
+import org.apache.flink.connector.file.src.FileSource;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.api.core.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.utils.LogicalTypeParser;
+import org.apache.flink.table.types.utils.LogicalTypeDataTypeConverter;
+import org.apache.flink.types.Row;
+
+import org.apache.flink.shaded.curator4.com.google.common.base.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+import static org.apache.flink.ml.classification.knn.KnnUtil.extractObject;
+import static org.apache.flink.ml.classification.knn.KnnUtil.findColTypes;
+import static org.apache.flink.ml.classification.knn.KnnUtil.pGson;
+import static 
org.apache.flink.ml.classification.knn.KnnUtil.resolvedSchema2Schema;
+
+/** Knn classification model fitted by KnnClassifier. */
+public class KnnModel implements Model<KnnModel>, KnnParams<KnnModel> {
+
+    private static final long serialVersionUID = 1303892137143865652L;
+
+    public static final String BROADCAST_STR = "broadcastModelKey";
+    private static final int FASTDISTANCE_TYPE_INDEX = 0;
+    private static final int DATA_INDEX = 1;
+
+    protected Map<Param<?>, Object> params = new HashMap<>();
+
+    private Table[] modelData;
+
    /** Constructs a KnnModel with every parameter initialized to its declared default value. */
    public KnnModel() {
        ParamUtils.initializeMapWithDefaultValues(params, this);
    }
+
    /**
     * Constructs a KnnModel backed by the given parameter map.
     *
     * <p>NOTE(review): the map is stored by reference, not copied — later mutations by the
     * caller are visible to this model. Confirm that aliasing is intended.
     *
     * @param params parameters for algorithm.
     */
    public KnnModel(Map<Param<?>, Object> params) {
        this.params = params;
    }
+
    /**
     * Sets the model data used for knn prediction.
     *
     * <p>Only {@code modelData[0]} is consumed by {@link #transform}; the table is stored by
     * reference without validation.
     *
     * @param modelData knn model.
     * @return this knn classification model, for call chaining.
     */
    @Override
    public KnnModel setModelData(Table... modelData) {
        this.modelData = modelData;
        return this;
    }
+
    /**
     * Returns the model data previously supplied via {@link #setModelData}.
     *
     * @return list of tables; may be {@code null} if no model data has been set.
     */
    @Override
    public Table[] getModelData() {
        return modelData;
    }
+
+    /**
+     * @param inputs a list of tables.
+     * @return result.
+     */
+    @Override
+    public Table[] transform(Table... inputs) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        DataStream<Row> input = tEnv.toDataStream(inputs[0]);
+        DataStream<Row> model = tEnv.toDataStream(modelData[0]);
+
+        Map<String, DataStream<?>> broadcastMap = new HashMap<>(1);
+        broadcastMap.put(BROADCAST_STR, model);
+        ResolvedSchema modelSchema = modelData[0].getResolvedSchema();
+        DataType idType =
+                
modelSchema.getColumnDataTypes().get(modelSchema.getColumnNames().size() - 1);
+
+        ResolvedSchema outputSchema =
+                getOutputSchema(inputs[0].getResolvedSchema(), getParamMap(), 
idType);
+
+        DataType[] dataTypes = outputSchema.getColumnDataTypes().toArray(new 
DataType[0]);
+        TypeInformation<?>[] typeInformations = new 
TypeInformation[dataTypes.length];
+
+        for (int i = 0; i < dataTypes.length; ++i) {
+            typeInformations[i] = 
TypeInformation.of(dataTypes[i].getLogicalType().getClass());
+        }
+
+        Function<List<DataStream<?>>, DataStream<Row>> function =
+                dataStreams -> {
+                    DataStream stream = dataStreams.get(0);
+                    return stream.transform(
+                            "mapFunc",
+                            new RowTypeInfo(
+                                    typeInformations,
+                                    outputSchema.getColumnNames().toArray(new 
String[0])),
+                            new PredictOperator(
+                                    new KnnRichFunction(
+                                            getParamMap(), 
inputs[0].getResolvedSchema())));
+                };
+
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(input), broadcastMap, 
function);
+        return new Table[] {tEnv.fromDataStream(output, 
resolvedSchema2Schema(outputSchema))};
+    }
+
    /**
     * Map function that scores one input row against the broadcast knn model data.
     *
     * <p>The model is lazily deserialized from the broadcast variable on the first call to
     * {@link #map}, because broadcast variables are only available at runtime.
     */
    private static class KnnRichFunction extends RichMapFunction<Row, Row> {
        // True until the broadcast model has been loaded on this task.
        private boolean firstEle = true;
        // Input columns copied through to the output row.
        private String[] reservedCols;
        // Numeric feature columns; mutually exclusive with vectorCol (set from model metadata).
        private String[] selectedCols;
        // Column holding a serialized dense vector; used when selectedCols is absent.
        private String vectorCol;
        // Data type of the label (id) column, parsed from the model metadata.
        private DataType idType;
        // Deserialized model; transient because it is rebuilt from the broadcast variable.
        private transient KnnModelData modelData;
        // Number of neighbors to consult (the "k" of knn).
        private final Integer topN;
        // Model metadata parsed from the broadcast rows (JSON map).
        private Map<String, Object> meta;

        public KnnRichFunction(Map<Param<?>, Object> params, ResolvedSchema dataSchema) {
            reservedCols = (String[]) params.get(KnnParams.RESERVED_COLS);
            // Default: reserve every input column when none are specified.
            reservedCols =
                    (reservedCols == null)
                            ? dataSchema.getColumnNames().toArray(new String[0])
                            : reservedCols;
            this.topN = (Integer) params.get(KnnParams.K);
        }

        /**
         * Predicts the label for one row: builds the feature vector, finds the top-N
         * neighbors, and emits the reserved columns plus the majority-vote prediction.
         */
        @Override
        public Row map(Row row) throws Exception {
            if (firstEle) {
                // Deferred model load: broadcast variables only exist at runtime.
                loadModel(getRuntimeContext().getBroadcastVariable(BROADCAST_STR));
                firstEle = false;
            }
            DenseVector vector;
            if (null != selectedCols) {
                // Assemble the vector from individual numeric feature columns.
                vector = new DenseVector(new double[selectedCols.length]);
                for (int i = 0; i < selectedCols.length; i++) {
                    Preconditions.checkNotNull(
                            row.getField(selectedCols[i]), "There is NULL in featureCols!");
                    vector.set(i, ((Number) row.getField(selectedCols[i])).doubleValue());
                }
            } else {
                // Otherwise parse the pre-built dense vector from its string column.
                vector = KnnUtil.parseDense(row.getField(vectorCol).toString());
            }
            // NOTE(review): the neighbor result is lower-cased before parsing — presumably to
            // normalize the serialized JSON; confirm labels are case-insensitive.
            String s = modelData.findNeighbor(vector, topN, null).toLowerCase();

            Row ret = new Row(reservedCols.length + 1);
            for (int i = 0; i < reservedCols.length; ++i) {
                ret.setField(i, row.getField(reservedCols[i]));
            }

            Tuple2<Object, String> tuple2 = getResultFormat(extractObject(s, idType));
            // Prediction goes into the last slot, after all reserved columns.
            ret.setField(reservedCols.length, tuple2.f0);
            return ret;
        }

        /**
         * get output format of knn predict result.
         *
         * <p>Computes a majority vote over the neighbor labels: each neighbor contributes an
         * equal weight of 1/N, and the label with the highest summed weight wins.
         *
         * @param tuple initial result from knn predictor; f0 holds the neighbor labels.
         * @return output format result: (winning label, JSON of per-label vote fractions).
         */
        private Tuple2<Object, String> getResultFormat(Tuple2<List<Object>, List<Object>> tuple) {
            // NOTE(review): divides by tuple.f0.size() — yields Infinity if the neighbor
            // list is empty; confirm findNeighbor never returns zero neighbors.
            double percent = 1.0 / tuple.f0.size();
            Map<Object, Double> detail = new HashMap<>(0);

            for (Object obj : tuple.f0) {
                detail.merge(obj, percent, Double::sum);
            }

            double max = 0.0;
            Object prediction = null;

            for (Map.Entry<Object, Double> entry : detail.entrySet()) {
                if (entry.getValue() > max) {
                    max = entry.getValue();
                    prediction = entry.getKey();
                }
            }

            return Tuple2.of(prediction, pGson.toJson(detail));
        }

        /**
         * Rebuilds the knn model from the broadcast rows: first extracts the JSON metadata
         * (stored at arity-2 of some row), then collects every serialized distance-matrix
         * block tagged with type 1.
         *
         * @param broadcastVar broadcast model rows.
         */
        public void loadModel(List<Object> broadcastVar) {
            List<FastDistanceMatrixData> dictData = new ArrayList<>();
            for (Object obj : broadcastVar) {
                Row row = (Row) obj;
                // Metadata lives in the second-to-last field of whichever row carries it.
                if (row.getField(row.getArity() - 2) != null) {
                    meta = pGson.fromJson((String) row.getField(row.getArity() - 2), HashMap.class);
                }
            }
            for (Object obj : broadcastVar) {
                Row row = (Row) obj;
                if (row.getField(FASTDISTANCE_TYPE_INDEX) != null) {
                    long type = (long) row.getField(FASTDISTANCE_TYPE_INDEX);
                    // Type tag 1 marks a serialized FastDistanceMatrixData payload.
                    if (type == 1L) {
                        dictData.add(
                                FastDistanceMatrixData.fromString(
                                        (String) row.getField(DATA_INDEX)));
                    }
                }
            }
            // NOTE(review): if no broadcast row carried metadata, `meta` is still null here
            // and the calls below throw NPE — confirm the model writer always emits it.
            if (meta.containsKey(KnnParams.FEATURE_COLS.name)) {
                selectedCols =
                        pGson.fromJson(
                                (String) meta.get(KnnParams.FEATURE_COLS.name), String[].class);
            } else {
                vectorCol =
                        pGson.fromJson((String) meta.get(KnnParams.VECTOR_COL.name), String.class);
            }

            modelData = new KnnModelData(dictData, new EuclideanDistance());
            // Restore the label type that save() serialized under the "idType" key.
            idType =
                    LogicalTypeDataTypeConverter.toDataType(
                            LogicalTypeParser.parse((String) this.meta.get("idType")));
            modelData.setIdType(idType);
        }
    }
+
    /**
     * Stream operator that delegates each element to the wrapped {@link RichMapFunction} for
     * prediction. Extending {@link AbstractUdfStreamOperator} lets the UDF participate in the
     * operator lifecycle (open/close, runtime context with broadcast variables).
     *
     * <p>If you want to write a prediction operator, you need to implement a special mapper
     * for this operator.
     */
    private static class PredictOperator
            extends AbstractUdfStreamOperator<Row, RichMapFunction<Row, Row>>
            implements OneInputStreamOperator<Row, Row> {

        public PredictOperator(RichMapFunction<Row, Row> userFunction) {
            super(userFunction);
        }

        /** Maps the incoming record through the UDF and emits the result downstream. */
        @Override
        public void processElement(StreamRecord<Row> streamRecord) throws Exception {
            Row value = streamRecord.getValue();
            output.collect(new StreamRecord<>(userFunction.map(value)));
        }
    }
+
+    public ResolvedSchema getOutputSchema(
+            ResolvedSchema dataSchema, Map<Param<?>, Object> params, DataType 
idType) {
+        String[] reservedCols = (String[]) params.get(KnnParams.RESERVED_COLS);
+        reservedCols =
+                (reservedCols == null)
+                        ? dataSchema.getColumnNames().toArray(new String[0])
+                        : reservedCols;
+        DataType[] reservedTypes = findColTypes(dataSchema, reservedCols);
+        String[] resultCols = new String[] {(String) 
params.get(KnnParams.PREDICTION_COL)};
+        DataType[] resultTypes = new DataType[] {idType};
+        return ResolvedSchema.physical(
+                ArrayUtils.addAll(reservedCols, resultCols),
+                ArrayUtils.addAll(reservedTypes, resultTypes));
+    }
+
+    /** @return parameters for algorithm. */
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        if (null == this.params) {
+            this.params = new HashMap<>(1);
+        }
+        return this.params;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
modelData[0]).getTableEnvironment();
+
+        String dataPath = ReadWriteUtils.getDataPath(path);
+        FileSink<Row> sink =
+                FileSink.forRowFormat(new Path(dataPath), new 
KnnModelData.ModelDataEncoder())
+                        .withRollingPolicy(OnCheckpointRollingPolicy.build())
+                        .withBucketAssigner(new BasePathBucketAssigner<>())
+                        .build();
+        tEnv.toDataStream(modelData[0]).sinkTo(sink);
+        HashMap<String, String> meta = new HashMap<>(1);
+        meta.put("idType", 
modelData[0].getResolvedSchema().getColumnDataTypes().get(3).toString());
+        ReadWriteUtils.saveMetadata(this, path, meta);
+    }
+
+    public void load(StreamExecutionEnvironment env, String path) throws 
IOException {

Review comment:
       This method should be static — I agree, and I will change it later.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to