lindong28 commented on a change in pull request #37: URL: https://github.com/apache/flink-ml/pull/37#discussion_r762376047
########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OneHotEncoderTest.java ########## @@ -0,0 +1,334 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature; + +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.feature.onehotencoder.OneHotEncoder; +import org.apache.flink.ml.feature.onehotencoder.OneHotEncoderModel; +import org.apache.flink.ml.feature.onehotencoder.OneHotEncoderModelData; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.ml.util.StageTestUtils; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; +import org.apache.flink.util.CloseableIterator; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.HashMap; +import java.util.Map; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +/** Tests OneHotEncoder and OneHotEncoderModel. 
*/ +public class OneHotEncoderTest { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + private Table trainTable; + private Table predictTable; + private Map<Double, Vector>[] expectedOutput; + private OneHotEncoder estimator; + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(4); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + + Schema schema = + Schema.newBuilder() + .column("f0", DataTypes.DOUBLE()) + .columnByMetadata("rowtime", "TIMESTAMP_LTZ(3)") + .watermark("rowtime", "SOURCE_WATERMARK()") + .build(); + + trainTable = + tEnv.fromDataStream( + env.fromElements(Row.of(0.0), Row.of(1.0), Row.of(2.0), Row.of(0.0))

Review comment: This line seems a bit too crowded. Putting all elements on the same line might not be very readable when there are many elements. Ideally we also want to use a consistent code style to construct input tables. Could we construct the list of input elements separately before constructing the table, following `KMeansTest` as an example? Same for `NaiveBayesTest`.
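For illustration, a minimal sketch of the suggested refactoring. The constant name `TRAIN_DATA` and the explicit `Types.ROW(Types.DOUBLE)` type information are assumptions for this sketch; the exact construction in `KMeansTest` may differ:

```java
// Input elements declared once, one per line, so the list stays readable as it grows.
// Assumes imports of java.util.Arrays, java.util.List and
// org.apache.flink.api.common.typeinfo.Types in addition to those shown above.
private static final List<Row> TRAIN_DATA =
        Arrays.asList(Row.of(0.0), Row.of(1.0), Row.of(2.0), Row.of(0.0));

// In before(): build the table from the pre-declared collection.
trainTable =
        tEnv.fromDataStream(
                        env.fromCollection(TRAIN_DATA, Types.ROW(Types.DOUBLE))
                                .assignTimestampsAndWatermarks(
                                        WatermarkStrategy.noWatermarks()),
                        schema)
                .as("input");
```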
########## File path: flink-ml-core/src/test/java/org/apache/flink/ml/linalg/VectorsTest.java ##########

@@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.linalg; + +import org.junit.Before; +import org.junit.Test; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +/** Tests the behavior of Vectors. */ +public class VectorsTest { + int n; + int[] indices; + double[] values; + + @Before + public void before() {

Review comment: Since we only have one test that needs these variables, the typical practice is to start simple and create these variables in the test directly. And if we do need to share these variables across multiple tests in the future, since these variables are pretty simple, it is probably simpler to initialize these variables when they are declared, instead of using `before()`.

########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderModel.java ##########

@@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.onehotencoder; + +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.api.common.functions.AbstractRichFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.connector.source.Source; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.connector.file.sink.FileSink; +import org.apache.flink.connector.file.src.FileSource; +import org.apache.flink.core.fs.Path; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner; +import org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy; +import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.table.runtime.typeutils.ExternalTypeInfo; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Vector; +import java.util.function.Function; + +/** + * A Model which encodes data into one-hot format using the model data computed by {@link + * OneHotEncoder}. + */ +public class OneHotEncoderModel + implements Model<OneHotEncoderModel>, OneHotEncoderParams<OneHotEncoderModel> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private Table modelTable; + + public OneHotEncoderModel() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table...
inputs) { + final String[] inputCols = getInputCols(); + final String[] outputCols = getOutputCols(); + final boolean dropLast = getDropLast(); + final String broadcastModelKey = "OneHotModelStream"; + + Preconditions.checkArgument(getHandleInvalid().equals("ERROR")); + Preconditions.checkArgument(inputs.length == 1); + Preconditions.checkArgument(inputCols.length == outputCols.length); + + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll( + inputTypeInfo.getFieldTypes(), + Collections.nCopies( + outputCols.length, + ExternalTypeInfo.of(Vector.class)) + .toArray(new TypeInformation[0])), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), outputCols)); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) modelTable).getTableEnvironment(); + DataStream<Row> input = tEnv.toDataStream(inputs[0]); + DataStream<Tuple2<Integer, Integer>> modelStream = + OneHotEncoderModelData.getModelDataStream(tEnv, modelTable); + + Map<String, DataStream<?>> broadcastMap = new HashMap<>(); + broadcastMap.put(broadcastModelKey, modelStream); + + Function<List<DataStream<?>>, DataStream<Row>> function = + dataStreams -> { + DataStream stream = dataStreams.get(0); + return stream.transform( + PredictLabelOperator.class.getSimpleName(), Review comment: nits: could we just use "PredictLabelOperator" for simplicity here? ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderModelData.java ########## @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.onehotencoder; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.serialization.Encoder; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.connector.file.src.reader.SimpleStreamFormat; +import org.apache.flink.core.fs.FSDataInputStream; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; + +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; + +import java.io.IOException; +import java.io.OutputStream; + +/** Provides classes to save/load model data. */ +public class OneHotEncoderModelData { + /** Converts the provided modelData Datastream into corresponding Table. 
*/ + public static Table getModelDataTable( + StreamTableEnvironment tEnv, DataStream<Tuple2<Integer, Integer>> stream) { + return tEnv.fromDataStream(stream); + } + + /** Converts the provided modelData Table into corresponding Datastream. */ + public static DataStream<Tuple2<Integer, Integer>> getModelDataStream( + StreamTableEnvironment tEnv, Table table) { + return tEnv.toDataStream(table) + .map( + new MapFunction<Row, Tuple2<Integer, Integer>>() { + @Override + public Tuple2<Integer, Integer> map(Row row) throws Exception { + return new Tuple2<>( + (int) row.getField("f0"), (int) row.getField("f1")); + } + }); + } + + /** Encoder for the OneHotEncoder model data. */ + public static class ModelDataEncoder implements Encoder<Tuple2<Integer, Integer>> { + @Override + public void encode(Tuple2<Integer, Integer> modeldata, OutputStream outputStream) + throws IOException { + Kryo kryo = new Kryo();

Review comment: Since the model data is a list of `Tuple2<Integer, Integer>`, could we simplify the encoder/decoder by not using `Kryo`?
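For illustration, one possible simplification, sketched under the assumption that the model data stays a stream of `Tuple2<Integer, Integer>`: write each pair as two raw ints with `DataOutputStream`/`DataInputStream` instead of Kryo (assumes imports of `java.io.DataInputStream`, `java.io.DataOutputStream`, `java.io.EOFException`, and `org.apache.flink.api.common.typeinfo.Types`):

```java
/** Writes each (columnIndex, maxValue) pair as two plain ints; no Kryo needed. */
public static class ModelDataEncoder implements Encoder<Tuple2<Integer, Integer>> {
    @Override
    public void encode(Tuple2<Integer, Integer> modelData, OutputStream outputStream)
            throws IOException {
        DataOutputStream dataStream = new DataOutputStream(outputStream);
        dataStream.writeInt(modelData.f0);
        dataStream.writeInt(modelData.f1);
    }
}

/** Reads pairs of ints back until the stream is exhausted. */
public static class ModelDataStreamFormat extends SimpleStreamFormat<Tuple2<Integer, Integer>> {
    @Override
    public Reader<Tuple2<Integer, Integer>> createReader(
            Configuration config, FSDataInputStream stream) throws IOException {
        return new Reader<Tuple2<Integer, Integer>>() {
            private final DataInputStream dataStream = new DataInputStream(stream);

            @Override
            public Tuple2<Integer, Integer> read() throws IOException {
                try {
                    return new Tuple2<>(dataStream.readInt(), dataStream.readInt());
                } catch (EOFException e) {
                    // Returning null signals end of input to the file source.
                    return null;
                }
            }

            @Override
            public void close() throws IOException {
                dataStream.close();
            }
        };
    }

    @Override
    public TypeInformation<Tuple2<Integer, Integer>> getProducedType() {
        return Types.TUPLE(Types.INT, Types.INT);
    }
}
```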
########## File path: flink-ml-core/src/main/java/org/apache/flink/ml/linalg/SparseVector.java ##########

@@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.linalg; + +import org.apache.flink.api.common.typeinfo.TypeInfo; +import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfoFactory; +import org.apache.flink.util.Preconditions; + +import java.util.Arrays; +import java.util.Objects; + +/** A sparse vector of double values. */ +@TypeInfo(SparseVectorTypeInfoFactory.class) +public class SparseVector implements Vector { + public final int n; + public final int[] indices; + public final double[] values; + + public SparseVector() {

Review comment: I guess this empty constructor was added to enable serialization. But `DenseVector` could be serialized without having an empty constructor. Could you remove this?

########## File path: flink-ml-core/src/main/java/org/apache/flink/ml/linalg/SparseVector.java ##########

@@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.linalg; + +import org.apache.flink.api.common.typeinfo.TypeInfo; +import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfoFactory; +import org.apache.flink.util.Preconditions; + +import java.util.Arrays; +import java.util.Objects; + +/** A sparse vector of double values. */ +@TypeInfo(SparseVectorTypeInfoFactory.class) +public class SparseVector implements Vector { + public final int n; + public final int[] indices; + public final double[] values; + + public SparseVector() { + this(-1); + } + + public SparseVector(int n) { + this(n, new int[0], new double[0]); + } + + public SparseVector(int n, int index, double value) {

Review comment: Can we start simple and keep only the constructor `SparseVector(int n, int[] indices, double[] values)`? IMO if there is only one index and one value, it is not too much work for the caller to construct `int[] indices` and `double[] values` in one line.

########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OneHotEncoderTest.java ##########

@@ -0,0 +1,334 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature; + +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.feature.onehotencoder.OneHotEncoder; +import org.apache.flink.ml.feature.onehotencoder.OneHotEncoderModel; +import org.apache.flink.ml.feature.onehotencoder.OneHotEncoderModelData; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.ml.util.StageTestUtils; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; +import org.apache.flink.util.CloseableIterator; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.HashMap; +import java.util.Map; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +/** Tests OneHotEncoder and OneHotEncoderModel.
*/ +public class OneHotEncoderTest { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + private Table trainTable; + private Table predictTable; + private Map<Double, Vector>[] expectedOutput; + private OneHotEncoder estimator; + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(4); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + + Schema schema = + Schema.newBuilder() + .column("f0", DataTypes.DOUBLE()) + .columnByMetadata("rowtime", "TIMESTAMP_LTZ(3)") + .watermark("rowtime", "SOURCE_WATERMARK()") + .build(); + + trainTable = + tEnv.fromDataStream( + env.fromElements(Row.of(0.0), Row.of(1.0), Row.of(2.0), Row.of(0.0)) + .assignTimestampsAndWatermarks( + WatermarkStrategy.noWatermarks()), + schema) + .as("input"); + + predictTable = + tEnv.fromDataStream( + env.fromElements(Row.of(0.0), Row.of(1.0), Row.of(2.0)) + .assignTimestampsAndWatermarks( + WatermarkStrategy.noWatermarks()), + schema) + .as("input"); + + expectedOutput = + new HashMap[] { + new HashMap<Double, Vector>() { + { + put(0.0, Vectors.sparse(2, 0, 1.0)); + put(1.0, Vectors.sparse(2, 1, 1.0)); + put(2.0, Vectors.sparse(2)); + } + } + }; + + estimator = new OneHotEncoder().setInputCols("input").setOutputCols("output"); + } + + /** + * Executes a given table and collect its results. Results are returned as a map array. Each + * element in the array is a map corresponding to a input column whose key is the original value + * in the input column, value is the one-hot encoding result of that value. 
+ * + * @param table A table to be executed and to have its result collected + * @param inputCols Name of the input columns + * @param outputCols Name of the output columns containing one-hot encoding result + * @return A map containing the collected results + */ + private static Map<Double, Vector>[] executeAndCollect( + Table table, String[] inputCols, String[] outputCols) { + Map<Double, Vector>[] maps = new HashMap[inputCols.length]; + for (int i = 0; i < inputCols.length; i++) { + maps[i] = new HashMap<>(); + } + for (CloseableIterator<Row> it = table.execute().collect(); it.hasNext(); ) { + Row row = it.next(); + for (int i = 0; i < inputCols.length; i++) { + maps[i].put( + ((Number) row.getField(inputCols[i])).doubleValue(), + (Vector) row.getField(outputCols[i])); + } + } + return maps; + } + + @Test + public void testParam() { + OneHotEncoder estimator = new OneHotEncoder(); + + assertTrue(estimator.getDropLast()); + + estimator.setInputCols("test_input").setOutputCols("test_output").setDropLast(false); + + assertArrayEquals(new String[] {"test_input"}, estimator.getInputCols()); + assertArrayEquals(new String[] {"test_output"}, estimator.getOutputCols()); + assertFalse(estimator.getDropLast()); + + OneHotEncoderModel model = new OneHotEncoderModel(); + + assertTrue(model.getDropLast()); + + model.setInputCols("test_input").setOutputCols("test_output").setDropLast(false); + + assertArrayEquals(new String[] {"test_input"}, model.getInputCols()); + assertArrayEquals(new String[] {"test_output"}, model.getOutputCols()); + assertFalse(model.getDropLast()); + } + + @Test + public void testFitAndPredict() { + OneHotEncoderModel model = estimator.fit(trainTable); + Table outputTable = model.transform(predictTable)[0]; + Map<Double, Vector>[] actualOutput = + executeAndCollect(outputTable, model.getInputCols(), model.getOutputCols()); + assertArrayEquals(expectedOutput, actualOutput); + } + + @Test + public void testDropLast() { + estimator.setDropLast(false); + + expectedOutput = + new HashMap[] { + new HashMap<Double, Vector>() { + { + put(0.0, Vectors.sparse(3, 0, 1.0)); + put(1.0, Vectors.sparse(3, 1, 1.0)); + put(2.0, Vectors.sparse(3, 2, 1.0)); + } + } + }; + + OneHotEncoderModel model = estimator.fit(trainTable); + Table outputTable = model.transform(predictTable)[0]; + Map<Double, Vector>[] actualOutput = + executeAndCollect(outputTable, model.getInputCols(), model.getOutputCols()); + assertArrayEquals(expectedOutput, actualOutput); + } + + @Test + public void testInputDataType() { + Schema schema = + Schema.newBuilder() + .column("f0", DataTypes.INT()) + .columnByMetadata("rowtime", "TIMESTAMP_LTZ(3)") Review comment: Could you explain why we need `columnByMetadata` and `watermark`? In other words, what would go wrong if we remove them? If we need them, maybe we need to consistently apply them across all tests (e.g. KMeansTest). ########## File path: flink-ml-core/src/main/java/org/apache/flink/ml/param/ParamValidators.java ########## @@ -95,4 +95,9 @@ public boolean validate(T value) { } }; } + + // Check if the length of the parameter value is greater than lowerBound. + public static <T> ParamValidator<T> lenGt(int lowerBound) { Review comment: Is there any use-case where lowerBound is not 0? If no, how about we use `nonEmptyArray()` for now? And do we need to check `value != null` in this method? 
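For illustration, a sketch of the `nonEmptyArray()` alternative suggested above, including the null check. It assumes `ParamValidator<T>` is the single-method interface shown in the surrounding hunk and that the validator is only applied to array-typed parameter values; `java.lang.reflect.Array.getLength(...)` is used so it works for arrays of any component type:

```java
// Checks that the parameter value is a non-null, non-empty array.
public static <T> ParamValidator<T> nonEmptyArray() {
    return new ParamValidator<T>() {
        @Override
        public boolean validate(T value) {
            return value != null && Array.getLength(value) > 0;
        }
    };
}
```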
########## File path: flink-ml-core/src/main/java/org/apache/flink/ml/linalg/SparseVector.java ##########

@@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.linalg; + +import org.apache.flink.api.common.typeinfo.TypeInfo; +import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfoFactory; +import org.apache.flink.util.Preconditions; + +import java.util.Arrays; +import java.util.Objects; + +/** A sparse vector of double values. */ +@TypeInfo(SparseVectorTypeInfoFactory.class) +public class SparseVector implements Vector { + public final int n; + public final int[] indices; + public final double[] values; + + public SparseVector() { + this(-1); + } + + public SparseVector(int n) { + this(n, new int[0], new double[0]); + } + + public SparseVector(int n, int index, double value) { + this(n, new int[] {index}, new double[] {value}); + } + + public SparseVector(int n, int[] indices, double[] values) { + this.n = n; + this.indices = indices; + this.values = values; + checkSizeAndIndicesRange(); + sortIndices(); + } + + @Override + public int size() { + return n; + } + + @Override + public double get(int i) { + int pos = Arrays.binarySearch(indices, i); + if (pos >= 0) { + return values[pos]; + } + return 0.; + } + + @Override + public double[] toArray() { + double[] result = new double[n]; + for (int i = 0; i < indices.length; i++) { + result[indices[i]] = values[i]; + } + return result; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + SparseVector that = (SparseVector) o; + return n == that.n + && Arrays.equals(indices, that.indices) + && Arrays.equals(values, that.values); + } + + @Override + public int hashCode() { + int result = Objects.hash(n); + result = 31 * result + Arrays.hashCode(indices); + result = 31 * result + Arrays.hashCode(values); + return result; + } + + /** + * Check whether the indices array and values array are of the same size, and whether vector + * indices are in valid range. + */ + private void checkSizeAndIndicesRange() {

Review comment: Could we also check whether there are duplicate indices? Spark does this check. And this can be done after the indices are sorted.
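For illustration, a small sketch of the duplicate check, run after `sortIndices()` so equal neighbors reveal duplicates in a single O(n) pass (the method name is hypothetical):

```java
// Assumes indices are already sorted ascending; duplicates then sit next to each other.
private void checkDuplicateIndices() {
    for (int i = 1; i < indices.length; i++) {
        Preconditions.checkArgument(
                indices[i] != indices[i - 1], "Indices must not contain duplicate values.");
    }
}
```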
########## File path: flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vectors.java ##########

@@ -25,4 +25,19 @@ public static DenseVector dense(double... values) { return new DenseVector(values); } + + /** Creates a sparse vector from its values. */ + public static SparseVector sparse(int size) {

Review comment: The Javadoc is not consistent with this method's signature. I am a bit worried that we will add too many "convenience" APIs in `Vectors`. Could we follow Spark's style and use only `sparse(int size, int[] indices, double[] values)` to create `SparseVector` for now? IMO if there is only one index and one value, it is not too much work for the caller to construct `int[] indices` and `double[] values` in one line.

########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OneHotEncoderTest.java ##########

@@ -0,0 +1,334 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature; + +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.feature.onehotencoder.OneHotEncoder; +import org.apache.flink.ml.feature.onehotencoder.OneHotEncoderModel; +import org.apache.flink.ml.feature.onehotencoder.OneHotEncoderModelData; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.ml.util.StageTestUtils; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; +import org.apache.flink.util.CloseableIterator; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.HashMap; +import java.util.Map; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +/** Tests OneHotEncoder and OneHotEncoderModel.
*/ +public class OneHotEncoderTest { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + private Table trainTable; + private Table predictTable; + private Map<Double, Vector>[] expectedOutput; + private OneHotEncoder estimator; + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(4); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + + Schema schema = + Schema.newBuilder() + .column("f0", DataTypes.DOUBLE()) + .columnByMetadata("rowtime", "TIMESTAMP_LTZ(3)") + .watermark("rowtime", "SOURCE_WATERMARK()") + .build(); + + trainTable = + tEnv.fromDataStream( + env.fromElements(Row.of(0.0), Row.of(1.0), Row.of(2.0), Row.of(0.0)) + .assignTimestampsAndWatermarks( + WatermarkStrategy.noWatermarks()), + schema) + .as("input"); + + predictTable = + tEnv.fromDataStream( + env.fromElements(Row.of(0.0), Row.of(1.0), Row.of(2.0)) + .assignTimestampsAndWatermarks( + WatermarkStrategy.noWatermarks()), + schema) + .as("input"); + + expectedOutput = + new HashMap[] { + new HashMap<Double, Vector>() { + { + put(0.0, Vectors.sparse(2, 0, 1.0)); + put(1.0, Vectors.sparse(2, 1, 1.0)); + put(2.0, Vectors.sparse(2)); + } + } + }; + + estimator = new OneHotEncoder().setInputCols("input").setOutputCols("output"); + } + + /** + * Executes a given table and collect its results. Results are returned as a map array. Each + * element in the array is a map corresponding to a input column whose key is the original value + * in the input column, value is the one-hot encoding result of that value. 
+ * + * @param table A table to be executed and to have its result collected + * @param inputCols Name of the input columns + * @param outputCols Name of the output columns containing one-hot encoding result + * @return A map containing the collected results + */ + private static Map<Double, Vector>[] executeAndCollect( + Table table, String[] inputCols, String[] outputCols) { + Map<Double, Vector>[] maps = new HashMap[inputCols.length]; + for (int i = 0; i < inputCols.length; i++) { + maps[i] = new HashMap<>(); + } + for (CloseableIterator<Row> it = table.execute().collect(); it.hasNext(); ) { + Row row = it.next(); + for (int i = 0; i < inputCols.length; i++) { + maps[i].put( + ((Number) row.getField(inputCols[i])).doubleValue(), + (Vector) row.getField(outputCols[i])); + } + } + return maps; + } + + @Test + public void testParam() { + OneHotEncoder estimator = new OneHotEncoder(); + + assertTrue(estimator.getDropLast()); + + estimator.setInputCols("test_input").setOutputCols("test_output").setDropLast(false); + + assertArrayEquals(new String[] {"test_input"}, estimator.getInputCols()); + assertArrayEquals(new String[] {"test_output"}, estimator.getOutputCols()); + assertFalse(estimator.getDropLast()); + + OneHotEncoderModel model = new OneHotEncoderModel(); + + assertTrue(model.getDropLast()); + + model.setInputCols("test_input").setOutputCols("test_output").setDropLast(false); + + assertArrayEquals(new String[] {"test_input"}, model.getInputCols()); + assertArrayEquals(new String[] {"test_output"}, model.getOutputCols()); + assertFalse(model.getDropLast()); + }

Review comment: Can we add a test similar to `NaiveBayesTest::testFeaturePredictionParam`?
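For illustration, a hedged sketch of what such a test could look like for this class, reusing the fields set up in `before()`. The test name and the exact assertions of `NaiveBayesTest::testFeaturePredictionParam` are assumptions, not quoted from that test:

```java
@Test
public void testInputOutputCols() {
    // Rename the input column and point the stage params at the non-default names,
    // then verify fit/transform still produce the expected encoding.
    trainTable = trainTable.as("test_input");
    predictTable = predictTable.as("test_input");
    estimator.setInputCols("test_input").setOutputCols("test_output");

    OneHotEncoderModel model = estimator.fit(trainTable);
    Table outputTable = model.transform(predictTable)[0];

    Map<Double, Vector>[] actualOutput =
            executeAndCollect(outputTable, model.getInputCols(), model.getOutputCols());
    assertArrayEquals(expectedOutput, actualOutput);
}
```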
########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModel.java ##########

@@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.naivebayes; + +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.api.common.functions.AbstractRichFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.connector.source.Source; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.connector.file.sink.FileSink; +import org.apache.flink.connector.file.src.FileSource; +import org.apache.flink.core.fs.Path; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner; +import org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy; +import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +/** A Model which classifies data using the model data computed by {@link NaiveBayes}. */ +public class NaiveBayesModel + implements Model<NaiveBayesModel>, NaiveBayesModelParams<NaiveBayesModel> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private Table modelDataTable; + + public NaiveBayesModel() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + + final String predictionCol = getPredictionCol(); + final String featuresCol = getFeaturesCol(); + final String broadcastModelKey = "NaiveBayesModelStream"; + + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll( + inputTypeInfo.getFieldTypes(), TypeInformation.of(Integer.class)), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), predictionCol)); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) modelDataTable).getTableEnvironment(); + DataStream<NaiveBayesModelData> modelDataStream =

Review comment: nits: could this be renamed as `modelData` for simplicity?

########## File path: flink-ml-core/src/main/java/org/apache/flink/ml/linalg/SparseVector.java ##########

@@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.linalg; + +import org.apache.flink.api.common.typeinfo.TypeInfo; +import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfoFactory; +import org.apache.flink.util.Preconditions; + +import java.util.Arrays; +import java.util.Objects; + +/** A sparse vector of double values. */ +@TypeInfo(SparseVectorTypeInfoFactory.class) +public class SparseVector implements Vector { + public final int n; + public final int[] indices; + public final double[] values; + + public SparseVector() { + this(-1); + } + + public SparseVector(int n) { + this(n, new int[0], new double[0]); + } + + public SparseVector(int n, int index, double value) { + this(n, new int[] {index}, new double[] {value}); + } + + public SparseVector(int n, int[] indices, double[] values) { + this.n = n; + this.indices = indices; + this.values = values; + checkSizeAndIndicesRange(); + sortIndices();

Review comment: The overhead of sorting an array is non-trivial (`O(n log n)` complexity), and it is better not to force this overhead on every SparseVector instantiation (e.g. if the given indices are already sorted). How about we move the sort logic to `Vectors.sparse(...)`, so that the caller has the option to skip the sorting overhead by calling `new SparseVector(...)` directly?
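For illustration, a sketch of that split, assuming the constructor then only validates. Since the values must be reordered together with the indices, the factory sorts the pairs via an argsort; the exact helper is an assumption, not the PR's implementation (assumes imports of `java.util.Arrays` and `java.util.Comparator`):

```java
// In Vectors: sort (index, value) pairs by index before delegating to the constructor.
public static SparseVector sparse(int size, int[] indices, double[] values) {
    Integer[] order = new Integer[indices.length];
    for (int i = 0; i < order.length; i++) {
        order[i] = i;
    }
    Arrays.sort(order, Comparator.comparingInt((Integer i) -> indices[i]));

    int[] sortedIndices = new int[indices.length];
    double[] sortedValues = new double[values.length];
    for (int i = 0; i < order.length; i++) {
        sortedIndices[i] = indices[order[i]];
        sortedValues[i] = values[order[i]];
    }
    // Callers that already hold sorted indices can use new SparseVector(...) directly
    // and skip this O(n log n) step.
    return new SparseVector(size, sortedIndices, sortedValues);
}
```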
########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java ##########

@@ -83,43 +101,26 @@ public void before() { dataTable = tEnv.fromDataStream(env.fromCollection(DATA), schema).as("features"); } - // Executes the graph and returns a map which maps points to clusterId. - private static Map<DenseVector, Integer> executeAndCollect( - Table output, String featureCol, String predictionCol) throws Exception { - StreamTableEnvironment tEnv = - (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment(); - - DataStream<Tuple2<DenseVector, Integer>> stream = - tEnv.toDataStream(output) - .map( - new MapFunction<Row, Tuple2<DenseVector, Integer>>() { - @Override - public Tuple2<DenseVector, Integer> map(Row row) { - return Tuple2.of( - (DenseVector) row.getField(featureCol), - (Integer) row.getField(predictionCol)); - } - }); - - List<Tuple2<DenseVector, Integer>> pointsWithClusterId = - IteratorUtils.toList(stream.executeAndCollect()); - - Map<DenseVector, Integer> clusterIdByPoints = new HashMap<>(); - for (Tuple2<DenseVector, Integer> entry : pointsWithClusterId) { - clusterIdByPoints.put(entry.f0, entry.f1); - } - return clusterIdByPoints; - } - - private static void verifyClusteringResult( - Map<DenseVector, Integer> clusterIdByPoints, List<List<Integer>> groups) { - for (List<Integer> group : groups) { - for (int i = 1; i < group.size(); i++) { - assertEquals( - clusterIdByPoints.get(DATA.get(group.get(0))), - clusterIdByPoints.get(DATA.get(group.get(i)))); - } + /** + * Executes a table and collects its results. Results are returned as a list of sets, where + * elements in the same set are features whose prediction results are the same. + * + * @param table A table to be executed and to have its result collected + * @param featureCol Name of the column in the table that contains the features + * @param predictionCol Name of the column in the table that contains the prediction result + * @return A map containing the collected results

Review comment: nits: a list of sets containing the collected results.

########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoder.java ##########

@@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.flink.ml.feature.onehotencoder; + +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.datastream.EndOfStreamWindows; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +/** + * An Estimator which implements the one-hot encoding algorithm. + * + * <p>See https://en.wikipedia.org/wiki/One-hot. + */ +public class OneHotEncoder + implements Estimator<OneHotEncoder, OneHotEncoderModel>, + OneHotEncoderParams<OneHotEncoder> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public OneHotEncoder() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public OneHotEncoderModel fit(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + Preconditions.checkArgument(getHandleInvalid().equals("ERROR")); + + final String[] inputCols = getInputCols(); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<Tuple2<Integer, Integer>> modelDataStream = + tEnv.toDataStream(inputs[0]) + .flatMap(new ExtractFeatureFunction(inputCols)) + .keyBy(x -> x.f0) + .window(EndOfStreamWindows.get()) + .reduce((x, y) -> new Tuple2<>(x.f0, Math.max(x.f1, y.f1))); + + OneHotEncoderModel model = + new OneHotEncoderModel() + .setModelData( + OneHotEncoderModelData.getModelDataTable(tEnv, modelDataStream)); + ReadWriteUtils.updateExistingParams(model, paramMap); + return model; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static OneHotEncoder load(StreamExecutionEnvironment env, String path) + throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + private static class ExtractFeatureFunction + implements FlatMapFunction<Row, Tuple2<Integer, Integer>> { + private final String[] inputCols; + + private ExtractFeatureFunction(String[] inputCols) { + this.inputCols = inputCols; + } + + @Override + public void flatMap(Row row, Collector<Tuple2<Integer, Integer>> collector) + throws Exception { + Number number; + for (int i = 0; i < inputCols.length; i++) { + number = (Number) row.getField(inputCols[i]); Review comment: Could we move `Number number` into the loop body to follow the existing code style? ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderModel.java ########## @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.onehotencoder; + +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.api.common.functions.AbstractRichFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.connector.source.Source; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.connector.file.sink.FileSink; +import org.apache.flink.connector.file.src.FileSource; +import org.apache.flink.core.fs.Path; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner; +import org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy; +import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.table.runtime.typeutils.ExternalTypeInfo; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Vector; +import java.util.function.Function; + +/** + * A Model which encodes data into one-hot format using the model data computed by {@link + * OneHotEncoder}. + */ +public class OneHotEncoderModel + implements Model<OneHotEncoderModel>, OneHotEncoderParams<OneHotEncoderModel> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private Table modelTable; Review comment: nits: could we rename this as `modelDataTable`? ########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/feature/OneHotEncoderTest.java ########## @@ -0,0 +1,334 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature; + +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.feature.onehotencoder.OneHotEncoder; +import org.apache.flink.ml.feature.onehotencoder.OneHotEncoderModel; +import org.apache.flink.ml.feature.onehotencoder.OneHotEncoderModelData; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.ml.util.StageTestUtils; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; +import org.apache.flink.util.CloseableIterator; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.HashMap; +import java.util.Map; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +/** Tests OneHotEncoder and OneHotEncoderModel. 
*/ +public class OneHotEncoderTest { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + private Table trainTable; + private Table predictTable; + private Map<Double, Vector>[] expectedOutput; + private OneHotEncoder estimator; + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(4); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + + Schema schema = + Schema.newBuilder() + .column("f0", DataTypes.DOUBLE()) + .columnByMetadata("rowtime", "TIMESTAMP_LTZ(3)") + .watermark("rowtime", "SOURCE_WATERMARK()") + .build(); + + trainTable = + tEnv.fromDataStream( + env.fromElements(Row.of(0.0), Row.of(1.0), Row.of(2.0), Row.of(0.0)) + .assignTimestampsAndWatermarks( + WatermarkStrategy.noWatermarks()), + schema) + .as("input"); + + predictTable = + tEnv.fromDataStream( + env.fromElements(Row.of(0.0), Row.of(1.0), Row.of(2.0)) + .assignTimestampsAndWatermarks( + WatermarkStrategy.noWatermarks()), + schema) + .as("input"); + + expectedOutput = + new HashMap[] { + new HashMap<Double, Vector>() { + { + put(0.0, Vectors.sparse(2, 0, 1.0)); + put(1.0, Vectors.sparse(2, 1, 1.0)); + put(2.0, Vectors.sparse(2)); + } + } + }; + + estimator = new OneHotEncoder().setInputCols("input").setOutputCols("output"); + } + + /** + * Executes a given table and collect its results. Results are returned as a map array. Each + * element in the array is a map corresponding to a input column whose key is the original value + * in the input column, value is the one-hot encoding result of that value. + * + * @param table A table to be executed and to have its result collected + * @param inputCols Name of the input columns + * @param outputCols Name of the output columns containing one-hot encoding result + * @return A map containing the collected results Review comment: nits: an array of maps containing the collected results ########## File path: flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorSerializer.java ########## @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.flink.ml.linalg.typeinfo; + +import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.ml.linalg.SparseVector; + +import java.io.IOException; +import java.util.Arrays; + +/** Specialized serializer for {@code Sparse}. */ Review comment: Could you help change this to `@link Sparse`? Same for `DenseVectorSerializer`. ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoderModel.java ########## @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.onehotencoder; + +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.api.common.functions.AbstractRichFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.connector.source.Source; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.connector.file.sink.FileSink; +import org.apache.flink.connector.file.src.FileSource; +import org.apache.flink.core.fs.Path; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner; +import org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy; +import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.table.runtime.typeutils.ExternalTypeInfo; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import 
java.util.Map; +import java.util.Vector; +import java.util.function.Function; + +/** + * A Model which encodes data into one-hot format using the model data computed by {@link + * OneHotEncoder}. + */ +public class OneHotEncoderModel + implements Model<OneHotEncoderModel>, OneHotEncoderParams<OneHotEncoderModel> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private Table modelTable; + + public OneHotEncoderModel() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... inputs) { + final String[] inputCols = getInputCols(); + final String[] outputCols = getOutputCols(); + final boolean dropLast = getDropLast(); + final String broadcastModelKey = "OneHotModelStream"; + + Preconditions.checkArgument(getHandleInvalid().equals("ERROR")); + Preconditions.checkArgument(inputs.length == 1); + Preconditions.checkArgument(inputCols.length == outputCols.length); + + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll( + inputTypeInfo.getFieldTypes(), + Collections.nCopies( + outputCols.length, + ExternalTypeInfo.of(Vector.class)) + .toArray(new TypeInformation[0])), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), outputCols)); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) modelTable).getTableEnvironment(); + DataStream<Row> input = tEnv.toDataStream(inputs[0]); + DataStream<Tuple2<Integer, Integer>> modelStream = + OneHotEncoderModelData.getModelDataStream(tEnv, modelTable); + + Map<String, DataStream<?>> broadcastMap = new HashMap<>(); + broadcastMap.put(broadcastModelKey, modelStream); + + Function<List<DataStream<?>>, DataStream<Row>> function = + dataStreams -> { + DataStream stream = dataStreams.get(0); + return stream.transform( + PredictLabelOperator.class.getSimpleName(), + outputTypeInfo, + new PredictLabelOperator(inputCols, dropLast, broadcastModelKey)); + }; + + DataStream<Row> output = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(input), broadcastMap, function); + + Table outputTable = tEnv.fromDataStream(output); + + return new Table[] {outputTable}; + } + + @Override + public void save(String path) throws IOException { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) modelTable).getTableEnvironment(); + + String dataPath = ReadWriteUtils.getDataPath(path); + FileSink<Tuple2<Integer, Integer>> sink = + FileSink.forRowFormat( + new Path(dataPath), new OneHotEncoderModelData.ModelDataEncoder()) + .withRollingPolicy(OnCheckpointRollingPolicy.build()) + .withBucketAssigner(new BasePathBucketAssigner<>()) + .build(); + OneHotEncoderModelData.getModelDataStream(tEnv, modelTable).sinkTo(sink); + + ReadWriteUtils.saveMetadata(this, path); + } + + public static OneHotEncoderModel load(StreamExecutionEnvironment env, String path) + throws IOException { + StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); + + Source<Tuple2<Integer, Integer>, ?, ?> source = + FileSource.forRecordStreamFormat( + new OneHotEncoderModelData.ModelDataStreamFormat(), + ReadWriteUtils.getDataPaths(path)) + .build(); + OneHotEncoderModel model = ReadWriteUtils.loadStageParam(path); + DataStream<Tuple2<Integer, Integer>> modelData = + env.fromSource(source, WatermarkStrategy.noWatermarks(), "modelData"); + model.setModelData(OneHotEncoderModelData.getModelDataTable(tEnv, modelData)); + + return model; + } + + @Override + public Map<Param<?>, Object> 
getParamMap() { + return paramMap; + } + + @Override + public OneHotEncoderModel setModelData(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + modelTable = inputs[0]; + return this; + } + + @Override + public Table[] getModelData() { + return new Table[] {modelTable}; + } + + private static class PredictLabelOperator + extends AbstractUdfStreamOperator<Row, AbstractRichFunction> + implements OneInputStreamOperator<Row, Row> { + private final String[] inputCols; + private final boolean dropLast; + private final String broadcastModelKey; + + public PredictLabelOperator( + String[] inputCols, boolean dropLast, String broadcastModelKey) { + super(new AbstractRichFunction() {}); + this.inputCols = inputCols; + this.dropLast = dropLast; + this.broadcastModelKey = broadcastModelKey; + } + + @Override + public void processElement(StreamRecord<Row> streamRecord) { + Row row = streamRecord.getValue(); + List<Tuple2<Integer, Integer>> model = + userFunction.getRuntimeContext().getBroadcastVariable(broadcastModelKey); Review comment: Can we call `getRuntimeContext().getBroadcastVariable(broadcastModelKey)` directly? Could this class extend `AbstractStreamOperator` instead of `AbstractUdfStreamOperator` for simplicity?
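A minimal sketch of the simplification the reviewer appears to have in mind, assuming `BroadcastUtils.withBroadcastStream` exposes the broadcast variable through the operator's own `StreamingRuntimeContext`; the encoding logic itself is elided:

```java
private static class PredictLabelOperator extends AbstractStreamOperator<Row>
        implements OneInputStreamOperator<Row, Row> {
    private final String[] inputCols;
    private final boolean dropLast;
    private final String broadcastModelKey;

    public PredictLabelOperator(
            String[] inputCols, boolean dropLast, String broadcastModelKey) {
        // No wrapped AbstractRichFunction is needed; AbstractStreamOperator
        // already provides getRuntimeContext().
        this.inputCols = inputCols;
        this.dropLast = dropLast;
        this.broadcastModelKey = broadcastModelKey;
    }

    @Override
    public void processElement(StreamRecord<Row> streamRecord) {
        Row row = streamRecord.getValue();
        List<Tuple2<Integer, Integer>> model =
                getRuntimeContext().getBroadcastVariable(broadcastModelKey);
        // ... same one-hot encoding logic as in the current version,
        // emitting results via output.collect(...) ...
    }
}
```

This drops the `userFunction.` indirection and the no-op `super(new AbstractRichFunction() {})` call.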
########## File path: flink-ml-core/src/main/java/org/apache/flink/ml/linalg/typeinfo/SparseVectorSerializer.java ########## @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.flink.ml.linalg.typeinfo; + +import org.apache.flink.api.common.typeutils.SimpleTypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot; +import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.ml.linalg.SparseVector; + +import java.io.IOException; +import java.util.Arrays; + +/** Specialized serializer for {@code Sparse}. */ +public final class SparseVectorSerializer extends TypeSerializerSingleton<SparseVector> { + + private static final long serialVersionUID = 1L; + + private static final SparseVectorSerializer INSTANCE = new SparseVectorSerializer(); + + @Override + public boolean isImmutableType() { + return false; + } + + @Override + public SparseVector createInstance() { + return new SparseVector(); + } + + @Override + public SparseVector copy(SparseVector from) { + return new SparseVector( + from.n, + Arrays.copyOf(from.indices, from.indices.length), + Arrays.copyOf(from.values, from.values.length)); + } + + @Override + public SparseVector copy(SparseVector from, SparseVector reuse) { + if (from.values.length == reuse.values.length && from.n == reuse.n) { + System.arraycopy(from.values, 0, reuse.values, 0, from.values.length); + System.arraycopy(from.indices, 0, reuse.indices, 0, from.indices.length); + return reuse; + } + return copy(from); + } + + @Override + public int getLength() { + return -1; + } + + @Override + public void serialize(SparseVector vector, DataOutputView target) throws IOException { + if (vector == null) { + throw new IllegalArgumentException("The vector must not be null."); + } + + target.writeInt(vector.n); + final int len = vector.values.length; + target.writeInt(len); + for (int i = 0; i < len; i++) { + target.writeInt(vector.indices[i]); + target.writeDouble(vector.values[i]); + } + } + + @Override + public SparseVector deserialize(DataInputView source) throws IOException { + int n = source.readInt(); + int len = source.readInt(); + int[] indices = new int[len]; + double[] values = new double[len]; + for (int i = 0; i < len; i++) { + indices[i] = source.readInt(); + values[i] = source.readDouble(); + } + return new SparseVector(n, indices, values); + } + + @Override + public SparseVector deserialize(SparseVector reuse, DataInputView source) throws IOException { + int n = source.readInt(); + int len = source.readInt(); + if (reuse.n == n && reuse.values.length == len) { + for (int i = 0; i < len; i++) { + reuse.indices[i] = source.readInt(); + reuse.values[i] = source.readDouble(); + } + return reuse; + } + + int[] indices = new int[len]; + double[] values = new double[len]; + for (int i = 0; i < len; i++) { + indices[i] = source.readInt(); Review comment: Could we put reusable logic in a static method, similar to `DenseVectorSerializer::readDoubleArray(...)`?
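A sketch of that refactoring; the helper name `readIndicesAndValues` is illustrative, not from the PR:

```java
/** Reads {@code len} (index, value) pairs from the input view into the given arrays. */
private static void readIndicesAndValues(
        int[] indices, double[] values, DataInputView source, int len) throws IOException {
    for (int i = 0; i < len; i++) {
        indices[i] = source.readInt();
        values[i] = source.readDouble();
    }
}
```

Both `deserialize(...)` variants could then delegate to it, e.g. `readIndicesAndValues(reuse.indices, reuse.values, source, len)` on the reuse path.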
########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModel.java ########## @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.naivebayes; + +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.api.common.functions.AbstractRichFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.connector.source.Source; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.connector.file.sink.FileSink; +import org.apache.flink.connector.file.src.FileSource; +import org.apache.flink.core.fs.Path; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner; +import org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy; +import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +/** A Model which classifies data using the model data computed by {@link NaiveBayes}. */ +public class NaiveBayesModel + implements Model<NaiveBayesModel>, NaiveBayesModelParams<NaiveBayesModel> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private Table modelDataTable; + + public NaiveBayesModel() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... 
inputs) { + Preconditions.checkArgument(inputs.length == 1); + + final String predictionCol = getPredictionCol(); + final String featuresCol = getFeaturesCol(); + final String broadcastModelKey = "NaiveBayesModelStream"; + + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll( + inputTypeInfo.getFieldTypes(), TypeInformation.of(Integer.class)), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), predictionCol)); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) modelDataTable).getTableEnvironment(); + DataStream<NaiveBayesModelData> modelDataStream = + NaiveBayesModelData.getModelDataStream(modelDataTable); + DataStream<Row> input = tEnv.toDataStream(inputs[0]); + + Map<String, DataStream<?>> broadcastMap = new HashMap<>(); + broadcastMap.put(broadcastModelKey, modelDataStream); + + Function<List<DataStream<?>>, DataStream<Row>> function = + dataStreams -> { + DataStream stream = dataStreams.get(0); + return stream.transform( + this.getClass().getSimpleName(), + outputTypeInfo, + new PredictLabelOperator(featuresCol, broadcastModelKey)); + }; + DataStream<Row> output = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(input), broadcastMap, function); + + Table outputTable = tEnv.fromDataStream(output); + + return new Table[] {outputTable}; + } + + @Override + public void save(String path) throws IOException { + String dataPath = ReadWriteUtils.getDataPath(path); + FileSink<NaiveBayesModelData> sink = + FileSink.forRowFormat( + new Path(dataPath), new NaiveBayesModelData.ModelDataEncoder()) + .withRollingPolicy(OnCheckpointRollingPolicy.build()) + .withBucketAssigner(new BasePathBucketAssigner<>()) + .build(); + NaiveBayesModelData.getModelDataStream(modelDataTable).sinkTo(sink); + + ReadWriteUtils.saveMetadata(this, path); + } + + public static NaiveBayesModel load(StreamExecutionEnvironment env, String path) + throws IOException { + Source<NaiveBayesModelData, ?, ?> source = + FileSource.forRecordStreamFormat( + new NaiveBayesModelData.ModelDataStreamFormat(), + ReadWriteUtils.getDataPaths(path)) + .build(); + NaiveBayesModel model = ReadWriteUtils.loadStageParam(path); + DataStream<NaiveBayesModelData> modelData = + env.fromSource(source, WatermarkStrategy.noWatermarks(), "modelData"); + model.setModelData(NaiveBayesModelData.getModelDataTable(modelData)); + + return model; + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + @Override + public NaiveBayesModel setModelData(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + modelDataTable = inputs[0]; + return this; + } + + @Override + public Table[] getModelData() { + return new Table[] {modelDataTable}; + } + + private static class PredictLabelOperator Review comment: This operator does not generate `label`. Should it be renamed as e.g. `GenerateOutputsOperator`? ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoder.java ########## @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.onehotencoder; + +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.datastream.EndOfStreamWindows; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +/** + * An Estimator which implements the one-hot encoding algorithm. + * + * <p>See https://en.wikipedia.org/wiki/One-hot. + */ +public class OneHotEncoder + implements Estimator<OneHotEncoder, OneHotEncoderModel>, + OneHotEncoderParams<OneHotEncoder> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public OneHotEncoder() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public OneHotEncoderModel fit(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + Preconditions.checkArgument(getHandleInvalid().equals("ERROR")); + + final String[] inputCols = getInputCols(); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<Tuple2<Integer, Integer>> modelDataStream = + tEnv.toDataStream(inputs[0]) + .flatMap(new ExtractFeatureFunction(inputCols)) + .keyBy(x -> x.f0) + .window(EndOfStreamWindows.get()) + .reduce((x, y) -> new Tuple2<>(x.f0, Math.max(x.f1, y.f1))); + + OneHotEncoderModel model = + new OneHotEncoderModel() + .setModelData( + OneHotEncoderModelData.getModelDataTable(tEnv, modelDataStream)); + ReadWriteUtils.updateExistingParams(model, paramMap); + return model; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static OneHotEncoder load(StreamExecutionEnvironment env, String path) + throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + private static class ExtractFeatureFunction Review comment: nits: Could we have comments here explaining the semantic meaning of its output? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
