This is an automated email from the ASF dual-hosted git repository.
lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new de08eb8 [FLINK-27096] Optimize OneHotEncoder performance
de08eb8 is described below
commit de08eb802e9a00ecf2cbc1ce38d648ef8fec1f5c
Author: yunfengzhou-hub <[email protected]>
AuthorDate: Tue Jun 21 20:16:44 2022 +0800
[FLINK-27096] Optimize OneHotEncoder performance
This closes #113.
---
.../datagenerator/common/DoubleGenerator.java | 27 +++-
.../main/resources/onehotencoder-benchmark.json | 35 +++++
.../ml/feature/onehotencoder/OneHotEncoder.java | 171 ++++++++++++++++-----
3 files changed, 195 insertions(+), 38 deletions(-)
diff --git a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DoubleGenerator.java b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DoubleGenerator.java
index 3dffe52..f4a1ba4 100644
--- a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DoubleGenerator.java
+++ b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DoubleGenerator.java
@@ -21,6 +21,9 @@ package org.apache.flink.ml.benchmark.datagenerator.common;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
import org.apache.flink.types.Row;
import org.apache.flink.util.Preconditions;
@@ -29,11 +32,29 @@ import java.util.Arrays;
/** A DataGenerator which creates a table of doubles. */
public class DoubleGenerator extends InputTableGenerator<DoubleGenerator> {
+ public static final Param<Integer> ARITY =
+ new IntParam(
+ "arity",
+ "Arity of the generated double values. "
+ + "If set to positive value, each feature would be an integer in range [0, arity - 1]. "
+ + "If set to zero, each feature would be a continuous double in range [0, 1).",
+ 0,
+ ParamValidators.gtEq(0));
+
+ public int getArity() {
+ return get(ARITY);
+ }
+
+ public DoubleGenerator setArity(int value) {
+ return set(ARITY, value);
+ }
+
@Override
protected RowGenerator[] getRowGenerators() {
String[][] colNames = getColNames();
Preconditions.checkState(colNames.length == 1);
int numOutputCols = colNames[0].length;
+ int arity = getArity();
return new RowGenerator[] {
new RowGenerator(getNumValues(), getSeed()) {
@@ -41,7 +62,11 @@ public class DoubleGenerator extends InputTableGenerator<DoubleGenerator> {
public Row nextRow() {
Row r = new Row(numOutputCols);
for (int i = 0; i < numOutputCols; i++) {
- r.setField(i, random.nextDouble());
+ if (arity > 0) {
+ r.setField(i, (double) random.nextInt(arity));
+ } else {
+ r.setField(i, random.nextDouble());
+ }
}
return r;
}
diff --git a/flink-ml-benchmark/src/main/resources/onehotencoder-benchmark.json
b/flink-ml-benchmark/src/main/resources/onehotencoder-benchmark.json
new file mode 100644
index 0000000..6195cc2
--- /dev/null
+++ b/flink-ml-benchmark/src/main/resources/onehotencoder-benchmark.json
@@ -0,0 +1,35 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+{
+ "version": 1,
+ "OneHotEncoder": {
+ "stage": {
+ "className": "org.apache.flink.ml.feature.onehotencoder.OneHotEncoder",
+ "paramMap": {
+ "inputCols": ["input"],
+ "outputCols": ["output"]
+ }
+ },
+ "inputData": {
+ "className": "org.apache.flink.ml.benchmark.datagenerator.common.DoubleGenerator",
+ "paramMap": {
+ "colNames": [["input"]],
+ "arity": 10,
+ "numValues": 100000
+ }
+ }
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoder.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoder.java
index 9bc1bb7..0543799 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoder.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoder.java
@@ -18,24 +18,35 @@
package org.apache.flink.ml.feature.onehotencoder;
-import org.apache.flink.api.common.functions.FlatMapFunction;
-import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
import org.apache.flink.ml.api.Estimator;
-import org.apache.flink.ml.common.datastream.DataStreamUtils;
import org.apache.flink.ml.common.param.HasHandleInvalid;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.types.Row;
-import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;
import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
@@ -68,13 +79,20 @@ public class OneHotEncoder
StreamTableEnvironment tEnv =
(StreamTableEnvironment) ((TableImpl)
inputs[0]).getTableEnvironment();
- DataStream<Tuple2<Integer, Integer>> columnsAndValues =
- tEnv.toDataStream(inputs[0]).flatMap(new ExtractInputColsValueFunction(inputCols));
+ DataStream<Integer[]> localMaxIndices =
+ tEnv.toDataStream(inputs[0])
+ .transform(
+ "ExtractInputValueAndFindMaxIndexOperator",
+ ObjectArrayTypeInfo.getInfoFor(BasicTypeInfo.INT_TYPE_INFO),
+ new ExtractInputValueAndFindMaxIndexOperator(inputCols));
DataStream<Tuple2<Integer, Integer>> modelData =
- DataStreamUtils.mapPartition(
- columnsAndValues.keyBy(columnIdAndValue -> columnIdAndValue.f0),
- new FindMaxIndexFunction());
+ localMaxIndices
+ .transform(
+ "GenerateModelDataOperator",
+ TupleTypeInfo.getBasicTupleTypeInfo(Integer.class, Integer.class),
+ new GenerateModelDataOperator())
+ .setParallelism(1);
OneHotEncoderModel model =
new OneHotEncoderModel().setModelData(tEnv.fromDataStream(modelData));
@@ -97,50 +115,129 @@ public class OneHotEncoder
}
/**
- * Extract values of input columns of input data.
- *
- * <p>Input: rows of input data containing designated input columns
- *
- * <p>Output: Pairs of column index and value stored in those columns
+ * Operator to extract the integer values from input columns and to find the max index value for
+ * each column.
*/
- private static class ExtractInputColsValueFunction
- implements FlatMapFunction<Row, Tuple2<Integer, Integer>> {
+ private static class ExtractInputValueAndFindMaxIndexOperator
+ extends AbstractStreamOperator<Integer[]>
+ implements OneInputStreamOperator<Row, Integer[]>, BoundedOneInput {
+
private final String[] inputCols;
- private ExtractInputColsValueFunction(String[] inputCols) {
+ private ListState<Integer[]> maxIndicesState;
+
+ private Integer[] maxIndices;
+
+ private ExtractInputValueAndFindMaxIndexOperator(String[] inputCols) {
this.inputCols = inputCols;
}
@Override
- public void flatMap(Row row, Collector<Tuple2<Integer, Integer>> collector) {
+ public void initializeState(StateInitializationContext context) throws Exception {
+ super.initializeState(context);
+
+ TypeInformation<Integer[]> type =
+ ObjectArrayTypeInfo.getInfoFor(BasicTypeInfo.INT_TYPE_INFO);
+
+ maxIndicesState =
+ context.getOperatorStateStore()
+ .getListState(new ListStateDescriptor<>("maxIndices", type));
+
+ maxIndices =
+ OperatorStateUtils.getUniqueElement(maxIndicesState, "maxIndices")
+ .orElse(initMaxIndices());
+ }
+
+ private Integer[] initMaxIndices() {
+ Integer[] indices = new Integer[inputCols.length];
+ Arrays.fill(indices, Integer.MIN_VALUE);
+ return indices;
+ }
+
+ @Override
+ public void snapshotState(StateSnapshotContext context) throws Exception {
+ super.snapshotState(context);
+ maxIndicesState.update(Collections.singletonList(maxIndices));
+ }
+
+ @Override
+ public void processElement(StreamRecord<Row> streamRecord) {
+ Row row = streamRecord.getValue();
for (int i = 0; i < inputCols.length; i++) {
Number number = (Number) row.getField(inputCols[i]);
- Preconditions.checkArgument(
- number.intValue() == number.doubleValue(),
- String.format("Value %s cannot be parsed as indexed integer.", number));
- Preconditions.checkArgument(
- number.intValue() >= 0, "Negative value not supported.");
- collector.collect(new Tuple2<>(i, number.intValue()));
+ int value = number.intValue();
+
+ if (value != number.doubleValue()) {
+ throw new IllegalArgumentException(
+ String.format("Value %s cannot be parsed as indexed integer.", number));
+ }
+ Preconditions.checkArgument(value >= 0, "Negative value not supported.");
+
+ if (value > maxIndices[i]) {
+ maxIndices[i] = value;
+ }
}
}
+
+ @Override
+ public void endInput() {
+ output.collect(new StreamRecord<>(maxIndices));
+ }
}
- /** Function to find the max index value for each column. */
- private static class FindMaxIndexFunction
- implements MapPartitionFunction<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> {
+ /**
+ * Collects and reduces the max index value in each column and produces the model data.
+ *
+ * <p>Output: Pairs of column index and max index value in this column.
+ */
+ private static class GenerateModelDataOperator
+ extends AbstractStreamOperator<Tuple2<Integer, Integer>>
+ implements OneInputStreamOperator<Integer[], Tuple2<Integer, Integer>>,
+ BoundedOneInput {
+
+ private ListState<Integer[]> maxIndicesState;
+
+ private Integer[] maxIndices;
@Override
- public void mapPartition(
- Iterable<Tuple2<Integer, Integer>> iterable,
- Collector<Tuple2<Integer, Integer>> collector) {
- Map<Integer, Integer> map = new HashMap<>();
- for (Tuple2<Integer, Integer> value : iterable) {
- map.put(
- value.f0,
- Math.max(map.getOrDefault(value.f0, Integer.MIN_VALUE), value.f1));
+ public void initializeState(StateInitializationContext context) throws Exception {
+ super.initializeState(context);
+
+ TypeInformation<Integer[]> type =
+ ObjectArrayTypeInfo.getInfoFor(BasicTypeInfo.INT_TYPE_INFO);
+
+ maxIndicesState =
+ context.getOperatorStateStore()
+ .getListState(new ListStateDescriptor<>("maxIndices", type));
+
+ maxIndices =
+ OperatorStateUtils.getUniqueElement(maxIndicesState, "maxIndices").orElse(null);
+ }
+
+ @Override
+ public void snapshotState(StateSnapshotContext context) throws Exception {
+ super.snapshotState(context);
+ maxIndicesState.update(Collections.singletonList(maxIndices));
+ }
+
+ @Override
+ public void processElement(StreamRecord<Integer[]> streamRecord) {
+ if (maxIndices == null) {
+ maxIndices = streamRecord.getValue();
+ } else {
+ Integer[] indices = streamRecord.getValue();
+ for (int i = 0; i < maxIndices.length; i++) {
+ if (indices[i] > maxIndices[i]) {
+ maxIndices[i] = indices[i];
+ }
+ }
}
- for (Map.Entry<Integer, Integer> entry : map.entrySet()) {
- collector.collect(new Tuple2<>(entry.getKey(), entry.getValue()));
+ }
+
+ @Override
+ public void endInput() {
+ for (int i = 0; i < maxIndices.length; i++) {
+ output.collect(new StreamRecord<>(Tuple2.of(i, maxIndices[i])));
}
}
}