yunfengzhou-hub commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r752833908



##########
File path: flink-ml-lib/pom.xml
##########
@@ -106,6 +112,11 @@ under the License.
       <type>jar</type>
       <scope>test</scope>
     </dependency>
+      <dependency>
+          <groupId>com.google.code.gson</groupId>
+          <artifactId>gson</artifactId>
+          <version>2.8.6</version>
+      </dependency>

Review comment:
       `ReadWriteUtils` already provides methods to generate JSON. Shall we reuse
the existing methods and avoid adding a new dependency?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnUtil.java
##########
@@ -0,0 +1,428 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.core.type.TypeReference;
+
+import com.google.gson.Gson;
+import com.google.gson.GsonBuilder;
+import org.apache.commons.lang3.StringUtils;
+import sun.reflect.generics.reflectiveObjects.ParameterizedTypeImpl;
+
+import java.lang.reflect.Type;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.PriorityQueue;
+import java.util.TreeMap;
+
+/** Utility class for the knn algorithm. */
+public class KnnUtil {

Review comment:
       Methods in this class could be placed in classes like `TableUtils` and 
`VectorUtils`.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/EuclideanDistance.java
##########
@@ -0,0 +1,272 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.flink.shaded.curator4.com.google.common.collect.Iterables;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+
+import static 
org.apache.flink.ml.classification.knn.KnnUtil.appendVectorToMatrix;
+
+/**
+ * Euclidean distance is the "ordinary" straight-line distance between two 
points in Euclidean
+ * space.
+ *
+ * <p>https://en.wikipedia.org/wiki/Euclidean_distance
+ *
+ * <p>Given two vectors a and b, Euclidean Distance = ||a - b||, where ||*|| 
means the L2 norm of
+ * the vector.
+ */
+public class EuclideanDistance implements Serializable {

Review comment:
       The design of an independent `EuclideanDistance` is different from both
Alink (which implements `FastDistance`) and Spark (which does not have such a
class). Flink ML also has `EuclideanDistanceMeasure`, which might help achieve
this functionality. It might be better to introduce distance classes while
following existing or discussed conventions.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModelData.java
##########
@@ -0,0 +1,223 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.types.Row;
+
+import org.apache.flink.shaded.curator4.com.google.common.collect.ImmutableMap;
+
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.List;
+import java.util.PriorityQueue;
+
+import static org.apache.flink.ml.classification.knn.KnnUtil.castTo;
+import static org.apache.flink.ml.classification.knn.KnnUtil.serializeResult;
+import static org.apache.flink.ml.classification.knn.KnnUtil.updateQueue;
+
+/** knn model data, which will be used to calculate the distances between 
nodes. */
+public class KnnModelData implements Serializable, Cloneable {
+    private static final long serialVersionUID = -2940551481683238630L;
+    private final List<FastDistanceMatrixData> dictData;
+    private final EuclideanDistance fastDistance;
+    protected Comparator<? super Tuple2<Double, Object>> comparator;
+    private DataType idType;
+
+    /**
+     * constructor.
+     *
+     * @param list BaseFastDistanceData list.
+     * @param fastDistance used to accelerate the speed of calculating 
distance.
+     */
+    public KnnModelData(List<FastDistanceMatrixData> list, EuclideanDistance 
fastDistance) {
+        this.dictData = list;
+        this.fastDistance = fastDistance;
+        comparator = Comparator.comparingDouble(o -> -o.f0);
+    }
+
+    /**
+     * set id type.
+     *
+     * @param idType id type.
+     */
+    public void setIdType(DataType idType) {
+        this.idType = idType;
+    }
+
+    /**
+     * find the nearest topN neighbors from whole nodes.
+     *
+     * @param input input node.
+     * @param topN top N.
+     * @param radius the parameter to describe the range to find neighbors.
+     * @return
+     */
+    public String findNeighbor(Object input, Integer topN, Double radius) {

Review comment:
       I personally think it would be better if we did not put calculation
logic in model data classes.

##########
File path: flink-ml-api/src/main/java/org/apache/flink/ml/linalg/Vector.java
##########
@@ -29,6 +29,10 @@
     /** Gets the value of the ith element. */
     double get(int i);
 
+
+    /** set the value of the ith element. */
+    void set(int i, double val);

Review comment:
       I am not sure we need a `set` method for `Vector`. From my perspective,
it might be better to make `Vector` immutable and create a new `Vector` if we
need to change it.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
##########
@@ -0,0 +1,220 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.api.core.Estimator;
+import org.apache.flink.ml.common.MapPartitionFunctionWrapper;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.apache.flink.ml.classification.knn.KnnUtil.findColIndex;
+import static org.apache.flink.ml.classification.knn.KnnUtil.findColIndices;
+import static org.apache.flink.ml.classification.knn.KnnUtil.merge;
+import static org.apache.flink.ml.classification.knn.KnnUtil.pGson;
+
+/**
+ * KNN classifier is to classify unlabeled observations by assigning them to 
the class of the most

Review comment:
       I suppose the term `classifier` has been removed according to previous
reviews. Comments like this should be updated.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnModel.java
##########
@@ -0,0 +1,343 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.connector.source.Source;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.connector.file.sink.FileSink;
+import org.apache.flink.connector.file.src.FileSource;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.api.core.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.table.types.logical.utils.LogicalTypeParser;
+import org.apache.flink.table.types.utils.LogicalTypeDataTypeConverter;
+import org.apache.flink.types.Row;
+
+import org.apache.flink.shaded.curator4.com.google.common.base.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+import static org.apache.flink.ml.classification.knn.KnnUtil.extractObject;
+import static org.apache.flink.ml.classification.knn.KnnUtil.findColTypes;
+import static org.apache.flink.ml.classification.knn.KnnUtil.pGson;
+import static 
org.apache.flink.ml.classification.knn.KnnUtil.resolvedSchema2Schema;
+
+/** Knn classification model fitted by KnnClassifier. */
+public class KnnModel implements Model<KnnModel>, KnnParams<KnnModel> {
+
+    private static final long serialVersionUID = 1303892137143865652L;
+
+    public static final String BROADCAST_STR = "broadcastModelKey";
+    private static final int FASTDISTANCE_TYPE_INDEX = 0;
+    private static final int DATA_INDEX = 1;
+
+    protected Map<Param<?>, Object> params = new HashMap<>();
+
+    private Table[] modelData;
+
+    /** constructor. */
+    public KnnModel() {
+        ParamUtils.initializeMapWithDefaultValues(params, this);
+    }
+
+    /**
+     * constructor.
+     *
+     * @param params parameters for algorithm.
+     */
+    public KnnModel(Map<Param<?>, Object> params) {
+        this.params = params;
+    }
+
+    /**
+     * Set model data for knn prediction.
+     *
+     * @param modelData knn model.
+     * @return knn classification model.
+     */
+    @Override
+    public KnnModel setModelData(Table... modelData) {
+        this.modelData = modelData;
+        return this;
+    }
+
+    /**
+     * get model data.
+     *
+     * @return list of tables.
+     */
+    @Override
+    public Table[] getModelData() {
+        return modelData;
+    }
+
+    /**
+     * @param inputs a list of tables.
+     * @return result.
+     */
+    @Override
+    public Table[] transform(Table... inputs) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        DataStream<Row> input = tEnv.toDataStream(inputs[0]);
+        DataStream<Row> model = tEnv.toDataStream(modelData[0]);
+
+        Map<String, DataStream<?>> broadcastMap = new HashMap<>(1);
+        broadcastMap.put(BROADCAST_STR, model);
+        ResolvedSchema modelSchema = modelData[0].getResolvedSchema();
+        DataType idType =
+                
modelSchema.getColumnDataTypes().get(modelSchema.getColumnNames().size() - 1);
+
+        ResolvedSchema outputSchema =
+                getOutputSchema(inputs[0].getResolvedSchema(), getParamMap(), 
idType);
+
+        DataType[] dataTypes = outputSchema.getColumnDataTypes().toArray(new 
DataType[0]);
+        TypeInformation<?>[] typeInformations = new 
TypeInformation[dataTypes.length];
+
+        for (int i = 0; i < dataTypes.length; ++i) {
+            typeInformations[i] = 
TypeInformation.of(dataTypes[i].getLogicalType().getClass());
+        }
+
+        Function<List<DataStream<?>>, DataStream<Row>> function =
+                dataStreams -> {
+                    DataStream stream = dataStreams.get(0);
+                    return stream.transform(
+                            "mapFunc",
+                            new RowTypeInfo(
+                                    typeInformations,
+                                    outputSchema.getColumnNames().toArray(new 
String[0])),
+                            new PredictOperator(
+                                    new KnnRichFunction(
+                                            getParamMap(), 
inputs[0].getResolvedSchema())));
+                };
+
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(input), broadcastMap, 
function);
+        return new Table[] {tEnv.fromDataStream(output, 
resolvedSchema2Schema(outputSchema))};
+    }
+
+    private static class KnnRichFunction extends RichMapFunction<Row, Row> {
+        private boolean firstEle = true;
+        private String[] reservedCols;
+        private String[] selectedCols;
+        private String vectorCol;
+        private DataType idType;
+        private transient KnnModelData modelData;
+        private final Integer topN;
+        private Map<String, Object> meta;
+
+        public KnnRichFunction(Map<Param<?>, Object> params, ResolvedSchema 
dataSchema) {
+            reservedCols = (String[]) params.get(KnnParams.RESERVED_COLS);
+            reservedCols =
+                    (reservedCols == null)
+                            ? dataSchema.getColumnNames().toArray(new 
String[0])
+                            : reservedCols;
+            this.topN = (Integer) params.get(KnnParams.K);
+        }
+
+        @Override
+        public Row map(Row row) throws Exception {
+            if (firstEle) {
+                
loadModel(getRuntimeContext().getBroadcastVariable(BROADCAST_STR));
+                firstEle = false;
+            }
+            DenseVector vector;
+            if (null != selectedCols) {
+                vector = new DenseVector(new double[selectedCols.length]);
+                for (int i = 0; i < selectedCols.length; i++) {
+                    Preconditions.checkNotNull(
+                            row.getField(selectedCols[i]), "There is NULL in 
featureCols!");
+                    vector.set(i, ((Number) 
row.getField(selectedCols[i])).doubleValue());
+                }
+            } else {
+                vector = 
KnnUtil.parseDense(row.getField(vectorCol).toString());
+            }
+            String s = modelData.findNeighbor(vector, topN, 
null).toLowerCase();
+
+            Row ret = new Row(reservedCols.length + 1);
+            for (int i = 0; i < reservedCols.length; ++i) {
+                ret.setField(i, row.getField(reservedCols[i]));
+            }
+
+            Tuple2<Object, String> tuple2 = getResultFormat(extractObject(s, 
idType));
+            ret.setField(reservedCols.length, tuple2.f0);
+            return ret;
+        }
+
+        /**
+         * get output format of knn predict result.
+         *
+         * @param tuple initial result from knn predictor.
+         * @return output format result.
+         */
+        private Tuple2<Object, String> getResultFormat(Tuple2<List<Object>, 
List<Object>> tuple) {
+            double percent = 1.0 / tuple.f0.size();
+            Map<Object, Double> detail = new HashMap<>(0);
+
+            for (Object obj : tuple.f0) {
+                detail.merge(obj, percent, Double::sum);
+            }
+
+            double max = 0.0;
+            Object prediction = null;
+
+            for (Map.Entry<Object, Double> entry : detail.entrySet()) {
+                if (entry.getValue() > max) {
+                    max = entry.getValue();
+                    prediction = entry.getKey();
+                }
+            }
+
+            return Tuple2.of(prediction, pGson.toJson(detail));
+        }
+
+        public void loadModel(List<Object> broadcastVar) {
+            List<FastDistanceMatrixData> dictData = new ArrayList<>();
+            for (Object obj : broadcastVar) {
+                Row row = (Row) obj;
+                if (row.getField(row.getArity() - 2) != null) {
+                    meta = pGson.fromJson((String) row.getField(row.getArity() 
- 2), HashMap.class);
+                }
+            }
+            for (Object obj : broadcastVar) {
+                Row row = (Row) obj;
+                if (row.getField(FASTDISTANCE_TYPE_INDEX) != null) {
+                    long type = (long) row.getField(FASTDISTANCE_TYPE_INDEX);
+                    if (type == 1L) {
+                        dictData.add(
+                                FastDistanceMatrixData.fromString(
+                                        (String) row.getField(DATA_INDEX)));
+                    }
+                }
+            }
+            if (meta.containsKey(KnnParams.FEATURE_COLS.name)) {
+                selectedCols =
+                        pGson.fromJson(
+                                (String) 
meta.get(KnnParams.FEATURE_COLS.name), String[].class);
+            } else {
+                vectorCol =
+                        pGson.fromJson((String) 
meta.get(KnnParams.VECTOR_COL.name), String.class);
+            }
+
+            modelData = new KnnModelData(dictData, new EuclideanDistance());
+            idType =
+                    LogicalTypeDataTypeConverter.toDataType(
+                            LogicalTypeParser.parse((String) 
this.meta.get("idType")));
+            modelData.setIdType(idType);
+        }
+    }
+
+    /**
+     * this operator use mapper to load the model data and do the prediction. 
if you want to write a
+     * prediction operator, you need implement a special mapper for this 
operator.
+     */
+    private static class PredictOperator
+            extends AbstractUdfStreamOperator<Row, RichMapFunction<Row, Row>>
+            implements OneInputStreamOperator<Row, Row> {
+
+        public PredictOperator(RichMapFunction<Row, Row> userFunction) {
+            super(userFunction);
+        }
+
+        @Override
+        public void processElement(StreamRecord<Row> streamRecord) throws 
Exception {
+            Row value = streamRecord.getValue();
+            output.collect(new StreamRecord<>(userFunction.map(value)));
+        }
+    }
+
+    public ResolvedSchema getOutputSchema(
+            ResolvedSchema dataSchema, Map<Param<?>, Object> params, DataType 
idType) {
+        String[] reservedCols = (String[]) params.get(KnnParams.RESERVED_COLS);
+        reservedCols =
+                (reservedCols == null)
+                        ? dataSchema.getColumnNames().toArray(new String[0])
+                        : reservedCols;
+        DataType[] reservedTypes = findColTypes(dataSchema, reservedCols);
+        String[] resultCols = new String[] {(String) 
params.get(KnnParams.PREDICTION_COL)};
+        DataType[] resultTypes = new DataType[] {idType};
+        return ResolvedSchema.physical(
+                ArrayUtils.addAll(reservedCols, resultCols),
+                ArrayUtils.addAll(reservedTypes, resultTypes));
+    }
+
+    /** @return parameters for algorithm. */
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        if (null == this.params) {
+            this.params = new HashMap<>(1);
+        }
+        return this.params;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
modelData[0]).getTableEnvironment();
+
+        String dataPath = ReadWriteUtils.getDataPath(path);
+        FileSink<Row> sink =
+                FileSink.forRowFormat(new Path(dataPath), new 
KnnModelData.ModelDataEncoder())
+                        .withRollingPolicy(OnCheckpointRollingPolicy.build())
+                        .withBucketAssigner(new BasePathBucketAssigner<>())
+                        .build();
+        tEnv.toDataStream(modelData[0]).sinkTo(sink);
+        HashMap<String, String> meta = new HashMap<>(1);
+        meta.put("idType", 
modelData[0].getResolvedSchema().getColumnDataTypes().get(3).toString());
+        ReadWriteUtils.saveMetadata(this, path, meta);
+    }
+
+    public void load(StreamExecutionEnvironment env, String path) throws 
IOException {

Review comment:
       This method should be static.
   
   The corresponding `Estimator` should also have a static `load` method.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/DenseMatrix.java
##########
@@ -0,0 +1,161 @@
+package org.apache.flink.ml.classification.knn;
+
+import java.io.Serializable;
+
+/**
+ * Knn DenseMatrix stores dense matrix data and provides some methods to 
operate on the matrix it
+ * represents. This data structure helps knn to accelerate distance 
calculation.
+ */
+public class DenseMatrix implements Serializable {

Review comment:
       Instead of putting `DenseMatrix` and the distance classes in this package,
I think it would be better to place them in common packages like `distance` or
`linalg`.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to