weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r755688178



##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
##########
@@ -0,0 +1,255 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.api.core.Estimator;
+import org.apache.flink.ml.common.MapPartitionFunctionWrapper;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * KNN is to classify unlabeled observations by assigning them to the class of the most similar
+ * labeled examples.
+ */
+public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> {
+
+    private static final long serialVersionUID = 5292477422193301398L;
+    private static final int ROW_SIZE = 2;
+    private static final int FASTDISTANCE_TYPE_INDEX = 0;
+    private static final int DATA_INDEX = 1;
+
+    protected Map<Param<?>, Object> params = new HashMap<>();
+
+    /** constructor. */
+    public Knn() {
+        ParamUtils.initializeMapWithDefaultValues(params, this);
+    }
+
+    /**
+     * constructor.
+     *
+     * @param params parameters for algorithm.
+     */
+    public Knn(Map<Param<?>, Object> params) {
+        this.params = params;
+    }
+
+    /**
+     * @param inputs a list of tables
+     * @return knn classification model.
+     */
+    @Override
+    public KnnModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        ResolvedSchema schema = inputs[0].getResolvedSchema();
+        String[] colNames = schema.getColumnNames().toArray(new String[0]);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Row> input = tEnv.toDataStream(inputs[0]);
+        String[] targetCols = getFeatureCols();
+        final int[] featureIndices;
+        if (targetCols == null) {
+            featureIndices = new int[colNames.length];
+            for (int i = 0; i < colNames.length; i++) {
+                featureIndices[i] = i;
+            }
+        } else {
+            featureIndices = new int[targetCols.length];
+            for (int i = 0; i < featureIndices.length; i++) {
+                featureIndices[i] = findColIndex(colNames, targetCols[i]);
+            }
+        }
+        String labelCol = getLabelCol();
+        final int labelIdx = findColIndex(colNames, labelCol);
+        final int vecIdx =
+                getVectorCol() != null
+                        ? findColIndex(
+                                inputs[0]
+                                        .getResolvedSchema()
+                                        .getColumnNames()
+                                        .toArray(new String[0]),
+                                getVectorCol())
+                        : -1;
+
+        DataStream<Row> trainData =
+                input.map(
+                        (MapFunction<Row, Row>)
+                                value -> {
+                                    Object label = value.getField(labelIdx);

Review comment:
       done

##########
File path: flink-ml-api/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
##########
@@ -43,8 +46,12 @@
 
 /** Utility methods for reading and writing stages. */
 public class ReadWriteUtils {
-    public static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
-
+    public static Gson OBJECT_MAPPER =

Review comment:
       done




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to