weibozhao commented on a change in pull request #24:
URL: https://github.com/apache/flink-ml/pull/24#discussion_r760919046



##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/classification/knn/KnnTest.java
##########
@@ -0,0 +1,273 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.api.Stage;
+import org.apache.flink.ml.builder.Pipeline;
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasK;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+
+/** Tests {@link Knn} and {@link KnnModel}. */
+public class KnnTest {
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainData;
+
+    List<Row> trainArray =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of("f", "2.0 3.0", 1, 0, 1.47),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.5),
+                            Row.of("m", "200.1 300.1", 1, 0, 1.5),
+                            Row.of("m", "200.2 300.2", 1, 0, 2.59),
+                            Row.of("m", "200.3 300.3", 1, 0, 2.55),
+                            Row.of("m", "200.4 300.4", 1, 0, 2.53),
+                            Row.of("m", "200.4 300.4", 1, 0, 2.52),
+                            Row.of("m", "200.6 300.6", 1, 0, 2.5),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.5),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.56),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.51),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.52),
+                            Row.of("f", "2.3 3.2", 1, 0, 1.53),
+                            Row.of("f", "2.3 3.2", 1, 0, 1.54),
+                            Row.of("c", "2.8 3.2", 3, 0, 1.6),
+                            Row.of("d", "300. 3.2", 5, 0, 1.5),
+                            Row.of("f", "2.2 3.2", 1, 0, 1.5),
+                            Row.of("e", "2.4 3.2", 2, 0, 1.3),
+                            Row.of("e", "2.5 3.2", 2, 0, 1.4),
+                            Row.of("e", "2.5 3.2", 2, 0, 1.5),
+                            Row.of("f", "2.1 3.1", 1, 0, 1.6)));
+
+    List<Row> testArray =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of("e", "4.0 4.1", 2, 0, 1.5),
+                            Row.of("m", "300 42", 1, 0, 2.59)));
+
+    private Table testData;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        DataStream<Row> dataStream =
+                env.fromCollection(
+                        trainArray,
+                        new RowTypeInfo(
+                                new TypeInformation[] {
+                                    Types.STRING, Types.STRING, Types.INT, Types.INT, Types.DOUBLE
+                                },
+                                new String[] {"label", "vec", "f0", "f1", "f2"}));
+        trainData = tEnv.fromDataStream(dataStream);
+
+        DataStream<Row> dataStreamStr =
+                env.fromCollection(
+                        testArray,
+                        new RowTypeInfo(
+                                    Types.STRING, Types.STRING, Types.INT, Types.INT, Types.DOUBLE
+                                },
+                                new String[] {"label", "vec", "f0", "f1", "f2"}));
+
+        testData = tEnv.fromDataStream(dataStreamStr);
+    }
+
+    /** Tests the Knn estimator built with fluent setters. */
+    @Test
+    public void testKnnEstimator() throws Exception {
+        Knn knn =
+                new Knn()
+                        .setLabelCol("label")
+                        .setFeaturesCol("vec")
+                        .setK(4)
+                        .setPredictionCol("pred");
+
+        KnnModel knnModel = knn.fit(trainData);
+        Table result = knnModel.transform(testData)[0];
+
+        DataStream<Row> output = tEnv.toDataStream(result);
+
+        List<Row> rows = IteratorUtils.toList(output.executeAndCollect());
+        for (Row value : rows) {
+            String label = (String) value.getField(0);
+            String pred = (String) value.getField(5);
+            assertEquals(label, pred);
+        }
+    }
+
+    /** Tests the Knn estimator constructed from a parameter map. */
+    @Test
+    public void testKnnEstimatorWithFeatures() throws Exception {
+        Map<Param<?>, Object> params = new HashMap<>();
+        params.put(HasLabelCol.LABEL_COL, "label");
+        params.put(HasFeaturesCol.FEATURES_COL, "vec");
+        params.put(HasK.K, 4);
+        params.put(HasPredictionCol.PREDICTION_COL, "pred");
+        Knn knn = new Knn(params);
+
+        KnnModel knnModel = knn.fit(trainData);
+        Table result = knnModel.transform(testData)[0];
+
+        DataStream<Row> output = tEnv.toDataStream(result);
+
+        List<Row> rows = IteratorUtils.toList(output.executeAndCollect());
+        for (Row value : rows) {
+            String label = (String) value.getField(0);
+            String pred = (String) value.getField(5);
+            assertEquals(label, pred);
+        }
+    }
+
+    /** Tests Knn as a pipeline stage. */
+    @Test
+    public void testKnnPipeline() throws Exception {
+        Knn knn =
+                new Knn()
+                        .setLabelCol("label")
+                        .setFeaturesCol("vec")
+                        .setK(4)
+                        .setPredictionCol("pred");
+
+        List<Stage<?>> stages = new ArrayList<>();
+        stages.add(knn);
+
+        Pipeline pipe = new Pipeline(stages);
+
+        Table result = pipe.fit(trainData).transform(testData)[0];
+
+        DataStream<Row> output = tEnv.toDataStream(result);
+
+        List<Row> rows = IteratorUtils.toList(output.executeAndCollect());
+        for (Row value : rows) {
+            String label = (String) value.getField(0);
+            String pred = (String) value.getField(5);
+            assertEquals(label, pred);
+        }
+    }
+
+    /** Tests saving and loading the Knn estimator and saving the fitted model. */
+    @Test
+    public void testKnnModelSave() throws Exception {
+        String knnPath = Files.createTempDirectory("").toString();
+        String modelPath = Files.createTempDirectory("").toString();
+        Knn knn =
+                new Knn().setLabelCol("f0").setFeaturesCol("vec").setK(4).setPredictionCol("pred");
+        knn.save(knnPath);
+        Knn cloneKnn = Knn.load(knnPath);
+        KnnModel knnModel = cloneKnn.fit(trainData);
+        knnModel.save(modelPath);
+        env.execute();
+    }
+
+    /** Tests loading the Knn model and transforming with it. */
+    @Test
+    public void testKnnModelLoad() throws Exception {

Review comment:
       done 
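
For reference, since the hunk above ends at the testKnnModelLoad declaration, here is a minimal
sketch of how such a load-and-transform round trip could look. It assumes KnnModel exposes a
static load(String) that mirrors the Knn.load(String) used in testKnnModelSave; this is only an
illustration, not the test body in this PR.

    /** Illustrative sketch only; KnnModel.load(String) is an assumed signature. */
    @Test
    public void knnModelLoadAndTransformSketch() throws Exception {
        String modelPath = Files.createTempDirectory("").toString();
        KnnModel knnModel =
                new Knn()
                        .setLabelCol("label")
                        .setFeaturesCol("vec")
                        .setK(4)
                        .setPredictionCol("pred")
                        .fit(trainData);
        knnModel.save(modelPath);
        env.execute();

        // Load the saved model data back and verify that predictions match the labels.
        KnnModel loadedModel = KnnModel.load(modelPath);
        Table result = loadedModel.transform(testData)[0];
        List<Row> rows = IteratorUtils.toList(tEnv.toDataStream(result).executeAndCollect());
        for (Row value : rows) {
            assertEquals(value.getField(0), value.getField(5));
        }
    }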

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasLabelCol.java
##########
@@ -0,0 +1,29 @@
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+import org.apache.flink.ml.param.WithParams;
+
+/**
+ * Param that specifies the name of the label column in the input table.
+ *
+ * @param <T>

Review comment:
       done
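
For context, the shared param interfaces in this package all follow the same small pattern. A
sketch of the likely shape of HasLabelCol given the imports in this hunk (the default value and
description text are assumptions):

    public interface HasLabelCol<T> extends WithParams<T> {
        // Shared param constant, referenced e.g. as HasLabelCol.LABEL_COL in KnnTest.
        Param<String> LABEL_COL =
                new StringParam(
                        "labelCol", "Name of the label column.", "label", ParamValidators.notNull());

        default String getLabelCol() {
            return get(LABEL_COL);
        }

        default T setLabelCol(String value) {
            return set(LABEL_COL, value);
        }
    }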

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/KnnParams.java
##########
@@ -0,0 +1,10 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasK;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+
+/** Params of {@link Knn}. */
+public interface KnnParams<T>
+        extends HasFeaturesCol<T>, HasLabelCol<T>, HasPredictionCol<T>, 
HasK<T> {}

Review comment:
       done

##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/knn/Knn.java
##########
@@ -0,0 +1,152 @@
+package org.apache.flink.ml.classification.knn;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.MapPartitionFunctionWrapper;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.VectorUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * KNN (k-nearest neighbours) classifies unlabeled observations by assigning them to the class of
+ * the most similar labeled examples.
+ */
+public class Knn implements Estimator<Knn, KnnModel>, KnnParams<Knn> {
+
+    protected Map<Param<?>, Object> params = new HashMap<>();
+
+    /** constructor. */
+    public Knn() {
+        ParamUtils.initializeMapWithDefaultValues(params, this);
+    }
+
+    /**
+     * constructor.
+     *
+     * @param params parameters for algorithm.
+     */
+    public Knn(Map<Param<?>, Object> params) {
+        this.params = params;
+    }
+
+    /**
+     * Fits a KNN classification model on the given training data.
+     *
+     * @param inputs a list of tables containing the training data
+     * @return the fitted KNN classification model
+     */
+    @Override
+    public KnnModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        ResolvedSchema schema = inputs[0].getResolvedSchema();
+        String[] colNames = schema.getColumnNames().toArray(new String[0]);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Row> input = tEnv.toDataStream(inputs[0]);
+
+        String labelCol = getLabelCol();
+        String vecCol = getFeaturesCol();
+
+        DataStream<Row> trainData =
+                input.map(
+                        (MapFunction<Row, Row>)
+                                value -> {
+                                    Object label = value.getField(labelCol);
+                                    DenseVector vec =
+                                            VectorUtils.parse(value.getField(vecCol).toString());
+                                    return Row.of(label, vec);
+                                });
+
+        DataType idType = schema.getColumnDataTypes().get(findColIndex(colNames, labelCol));
+        DataStream<Row> model = buildModel(trainData, getParamMap(), idType);
+        KnnModel knnModel = new KnnModel(params);
+        knnModel.setModelData(tEnv.fromDataStream(model));
+        return knnModel;
+    }
+
+    /**
+     * Builds the KNN model data.
+     *
+     * @param dataStream the training data
+     * @param params the algorithm parameters
+     * @param idType the data type of the label column
+     * @return the model data in stream format
+     */
+    private static DataStream<Row> buildModel(
+            DataStream<Row> dataStream, final Map<Param<?>, Object> params, final DataType idType) {
+        FastDistance fastDistance = new FastDistance();
+
+        return dataStream.transform(
+                "build index",
+                new RowTypeInfo(
+                        new TypeInformation[] {
+                            Types.STRING,
+                            TypeInformation.of(idType.getLogicalType().getDefaultConversion())
+                        },
+                        new String[] {"DATA", "KNN_LABEL_TYPE"}),

Review comment:
       done
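
As background on the class javadoc above: the prediction rule reduces to a distance computation
plus a majority vote over the k closest training points. A self-contained brute-force sketch of
that rule in plain Java follows (not the FastDistance-based index built in this PR, just the
underlying idea):

    import java.util.ArrayList;
    import java.util.Comparator;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    final class BruteForceKnnSketch {

        /** Returns the majority label among the k training points closest to the query. */
        static String predict(
                List<double[]> trainVectors, List<String> trainLabels, double[] query, int k) {
            // Sort the training indices by Euclidean distance to the query point.
            List<Integer> indices = new ArrayList<>();
            for (int i = 0; i < trainVectors.size(); i++) {
                indices.add(i);
            }
            indices.sort(Comparator.comparingDouble(i -> euclidean(trainVectors.get(i), query)));

            // Majority vote over the k nearest neighbours.
            Map<String, Integer> votes = new HashMap<>();
            for (int i = 0; i < Math.min(k, indices.size()); i++) {
                votes.merge(trainLabels.get(indices.get(i)), 1, Integer::sum);
            }
            return votes.entrySet().stream()
                    .max(Map.Entry.comparingByValue())
                    .map(Map.Entry::getKey)
                    .orElse(null);
        }

        private static double euclidean(double[] a, double[] b) {
            double sum = 0.0;
            for (int i = 0; i < a.length; i++) {
                double diff = a[i] - b[i];
                sum += diff * diff;
            }
            return Math.sqrt(sum);
        }
    }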



