zhipeng93 commented on a change in pull request #32:
URL: https://github.com/apache/flink-ml/pull/32#discussion_r748081703



##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java
##########
@@ -0,0 +1,333 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.naivebayes;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.ml.api.core.Estimator;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+
+/**
+ * Naive Bayes classifier is a simple probability classification algorithm 
using
+ * Bayes theorem based on independent assumption. It is an independent feature 
model.
+ * The input feature can be continual or categorical.
+ */
+public class NaiveBayes implements Estimator<NaiveBayes, NaiveBayesModel>,
+        NaiveBayesParams<NaiveBayes> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    @Override
+    public NaiveBayesModel fit(Table... inputs) {
+        String[] featureColNames = getFeatureCols();
+        String labelColName = getLabelCol();
+        String predictionCol = getPredictionCol();
+        double smoothing = getSmoothing();
+
+        Preconditions.checkNotNull(inputs, "input table list should not be 
null");
+        Preconditions.checkArgument(inputs.length == 1, "input table list 
should contain only one argument");
+        Preconditions.checkArgument(
+                new HashSet<>(Arrays.asList(featureColNames)).size() == 
featureColNames.length,
+                "feature columns should not duplicate");
+        Preconditions.checkNotNull(labelColName, "label column should be set");
+
+        StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        DataStream<Row> input = tEnv.toDataStream(inputs[0]);
+
+        DataStream<NaiveBayesModelData> naiveBayesModel = input
+                .flatMap(new FlattenFunction(
+                        featureColNames,
+                        labelColName
+                ))
+                .keyBy((KeySelector<Tuple4<Object, Integer, Object, Double>, 
Object>) value -> new Tuple3<>(value.f0, value.f1, value.f2))
+                .window(EndOfStreamWindows.get())
+                .reduce((ReduceFunction<Tuple4<Object, Integer, Object, 
Double>>) (t0, t1) -> {t0.f3 += t1.f3; return t0; })
+                .keyBy((KeySelector<Tuple4<Object, Integer, Object, Double>, 
Object>) value -> new Tuple2<>(value.f0, value.f1))
+                .window(EndOfStreamWindows.get())
+                .aggregate(new ValueMapFunction())
+                .keyBy((KeySelector<Tuple4<Object, Integer, Map<Object, 
Double>, Double>, Object>) value -> value.f0)
+                .window(EndOfStreamWindows.get())
+                .aggregate(new MapArrayFunction(featureColNames.length))
+                .windowAll(EndOfStreamWindows.get())
+                .apply(new GenerateModelFunction(
+                        smoothing,
+                        featureColNames));
+
+        NaiveBayesModel model = new NaiveBayesModel()
+                .setPredictionCol(predictionCol)
+                .setFeatureCols(featureColNames);
+        model.setModelData(
+                tEnv.fromDataStream(naiveBayesModel)
+        );
+        return model;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static NaiveBayes load(String path) throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    /**
+     * Function to convert each column into tuples of label, feature column 
index, feature value, weight.
+     */
+    private static class FlattenFunction implements FlatMapFunction<Row, 
Tuple4<Object, Integer, Object, Double>> {
+        private final String[] featureColNames;
+        private final String labelColName;
+        private final int featureSize;
+
+        private FlattenFunction(String[] featureColNames, String labelColName) 
{
+            this.labelColName = labelColName;
+            this.featureColNames = featureColNames;
+            this.featureSize = featureColNames.length;
+        }
+
+        @Override
+        public void flatMap(Row row, Collector<Tuple4<Object, Integer, Object, 
Double>> collector) {
+            Object label = row.getField(labelColName);
+            if (label == null) {
+                return;
+            }
+
+            for (int i = 0; i < featureSize; i++) {
+                Object feature = row.getField(featureColNames[i]);
+                if (feature == null) {
+                    continue;
+                }
+                collector.collect(new Tuple4<>(label, i, feature, 1.0));

Review comment:
       Should we support weighted Naive Bayes, or just remove the constant weight `1.0` here?

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java
##########
@@ -0,0 +1,333 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.naivebayes;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.ml.api.core.Estimator;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+
+/**
+ * Naive Bayes classifier is a simple probability classification algorithm 
using
+ * Bayes theorem based on independent assumption. It is an independent feature 
model.
+ * The input feature can be continual or categorical.
+ */
+public class NaiveBayes implements Estimator<NaiveBayes, NaiveBayesModel>,
+        NaiveBayesParams<NaiveBayes> {

Review comment:
       Shall we use different param classes for `NaiveBayes` and 
`NaiveBayesModel`?
   
   It may be confusing to users when they call `new 
NaiveBayesModel().setSmoothing(...)`, since `smoothing` has no effect during 
prediction.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/datastream/EndOfStreamWindows.java
##########
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.datastream;
+
+import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.windowing.assigners.WindowAssigner;
+import org.apache.flink.streaming.api.windowing.triggers.EventTimeTrigger;
+import org.apache.flink.streaming.api.windowing.triggers.Trigger;
+import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
+
+import java.util.Collection;
+import java.util.Collections;
+
+/**
+ * A {@link WindowAssigner} that windows elements into windows based on the 
timestamp of the
+ * elements. Windows cannot overlap.
+ *
+ * <p>For example, in order to window into windows of 1 minute:
+ *
+ * <pre>{@code
+ * DataStream<Tuple2<String, Integer>> in = ...;
+ * KeyedStream<Tuple2<String, Integer>, String> keyed = in.keyBy(...);
+ * WindowedStream<Tuple2<String, Integer>, String, TimeWindow> windowed =
+ *   keyed.window(TumblingEventTimeWindows.of(Time.minutes(1)));
+ * }</pre>
+ */
+@PublicEvolving
+public class EndOfStreamWindows extends WindowAssigner<Object, TimeWindow> {
+
+    private static final EndOfStreamWindows INSTANCE = new 
EndOfStreamWindows();
+
+    private EndOfStreamWindows() {}
+
+    public static EndOfStreamWindows get() {
+        return INSTANCE;
+    }
+
+    @Override
+    public Collection<TimeWindow> assignWindows(
+            Object element, long timestamp, WindowAssignerContext context) {
+        return Collections.singletonList(new TimeWindow(Long.MIN_VALUE, 
Long.MAX_VALUE));

Review comment:
       Why do you create a new TimeWindow on each call? Since the window is always the same, it could be created once and reused as a constant.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelData.java
##########
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.naivebayes;
+
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.io.Serializable;
+import java.util.Map;
+
+/**
+ * The model data of {@link NaiveBayesModel}.
+ */
+public class NaiveBayesModelData implements Serializable {
+    private static final long serialVersionUID = 3919917903722286395L;
+    public final String[] featureNames;
+    public final Map<Object, Double>[][] theta;
+    public final double[] piArray;
+    public final Object[] label;
+
+    public static final Schema SCHEMA =
+            Schema.newBuilder()
+                    .column("f0", DataTypes.of(NaiveBayesModelData.class))
+                    .build();
+
+    public NaiveBayesModelData(String[] featureNames, Map<Object, Double>[][] 
theta, double[] piArray, Object[] label) {
+        this.featureNames = featureNames;
+        this.theta = theta;
+        this.piArray = piArray;
+        this.label = label;
+    }
+
+    /** Encoder for the KMeans model data. */

Review comment:
       Update the comment here: it still refers to the KMeans model data instead of the Naive Bayes model data.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelData.java
##########
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.naivebayes;
+
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.io.Serializable;
+import java.util.Map;
+
+/**
+ * The model data of {@link NaiveBayesModel}.
+ */
+public class NaiveBayesModelData implements Serializable {
+    private static final long serialVersionUID = 3919917903722286395L;
+    public final String[] featureNames;
+    public final Map<Object, Double>[][] theta;
+    public final double[] piArray;
+    public final Object[] label;
+
+    public static final Schema SCHEMA =
+            Schema.newBuilder()
+                    .column("f0", DataTypes.of(NaiveBayesModelData.class))
+                    .build();
+
+    public NaiveBayesModelData(String[] featureNames, Map<Object, Double>[][] 
theta, double[] piArray, Object[] label) {
+        this.featureNames = featureNames;
+        this.theta = theta;
+        this.piArray = piArray;
+        this.label = label;
+    }
+
+    /** Encoder for the KMeans model data. */
+    public static class ModelDataEncoder implements 
Encoder<NaiveBayesModelData> {
+        @Override
+        public void encode(NaiveBayesModelData modelData, OutputStream 
outputStream) {
+            Kryo kryo = new Kryo();
+            Output output = new Output(outputStream);
+            kryo.writeObject(output, modelData);
+            output.flush();
+        }
+    }
+
+    /** Decoder for the KMeans model data. */

Review comment:
       Update the comment here: it still refers to the KMeans model data instead of the Naive Bayes model data.

##########
File path: 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java
##########
@@ -0,0 +1,333 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.naivebayes;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.ml.api.core.Estimator;
+import org.apache.flink.ml.common.datastream.EndOfStreamWindows;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.windowing.windows.TimeWindow;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+
+/**
+ * Naive Bayes classifier is a simple probability classification algorithm 
using
+ * Bayes theorem based on independent assumption. It is an independent feature 
model.
+ * The input feature can be continual or categorical.
+ */
+public class NaiveBayes implements Estimator<NaiveBayes, NaiveBayesModel>,
+        NaiveBayesParams<NaiveBayes> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    @Override
+    public NaiveBayesModel fit(Table... inputs) {
+        String[] featureColNames = getFeatureCols();
+        String labelColName = getLabelCol();
+        String predictionCol = getPredictionCol();
+        double smoothing = getSmoothing();
+
+        Preconditions.checkNotNull(inputs, "input table list should not be 
null");
+        Preconditions.checkArgument(inputs.length == 1, "input table list 
should contain only one argument");
+        Preconditions.checkArgument(
+                new HashSet<>(Arrays.asList(featureColNames)).size() == 
featureColNames.length,
+                "feature columns should not duplicate");
+        Preconditions.checkNotNull(labelColName, "label column should be set");
+
+        StreamTableEnvironment tEnv = (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        DataStream<Row> input = tEnv.toDataStream(inputs[0]);
+
+        DataStream<NaiveBayesModelData> naiveBayesModel = input
+                .flatMap(new FlattenFunction(
+                        featureColNames,
+                        labelColName
+                ))
+                .keyBy((KeySelector<Tuple4<Object, Integer, Object, Double>, 
Object>) value -> new Tuple3<>(value.f0, value.f1, value.f2))
+                .window(EndOfStreamWindows.get())
+                .reduce((ReduceFunction<Tuple4<Object, Integer, Object, 
Double>>) (t0, t1) -> {t0.f3 += t1.f3; return t0; })
+                .keyBy((KeySelector<Tuple4<Object, Integer, Object, Double>, 
Object>) value -> new Tuple2<>(value.f0, value.f1))
+                .window(EndOfStreamWindows.get())
+                .aggregate(new ValueMapFunction())
+                .keyBy((KeySelector<Tuple4<Object, Integer, Map<Object, 
Double>, Double>, Object>) value -> value.f0)
+                .window(EndOfStreamWindows.get())
+                .aggregate(new MapArrayFunction(featureColNames.length))
+                .windowAll(EndOfStreamWindows.get())

Review comment:
       When we use `windowAll`, we are actually changing the parallelism to one 
and using a single task to process all the elements. 
   
   Are there any efficiency issues?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@flink.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to