This is an automated email from the ASF dual-hosted git repository.

lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new de08eb8  [FLINK-27096] Optimize OneHotEncoder performance
de08eb8 is described below

commit de08eb802e9a00ecf2cbc1ce38d648ef8fec1f5c
Author: yunfengzhou-hub <[email protected]>
AuthorDate: Tue Jun 21 20:16:44 2022 +0800

    [FLINK-27096] Optimize OneHotEncoder performance
    
    This closes #113.
---
 .../datagenerator/common/DoubleGenerator.java      |  27 +++-
 .../main/resources/onehotencoder-benchmark.json    |  35 +++++
 .../ml/feature/onehotencoder/OneHotEncoder.java    | 171 ++++++++++++++++-----
 3 files changed, 195 insertions(+), 38 deletions(-)

diff --git a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DoubleGenerator.java b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DoubleGenerator.java
index 3dffe52..f4a1ba4 100644
--- a/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DoubleGenerator.java
+++ b/flink-ml-benchmark/src/main/java/org/apache/flink/ml/benchmark/datagenerator/common/DoubleGenerator.java
@@ -21,6 +21,9 @@ package org.apache.flink.ml.benchmark.datagenerator.common;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.common.typeinfo.Types;
 import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
 import org.apache.flink.types.Row;
 import org.apache.flink.util.Preconditions;
 
@@ -29,11 +32,29 @@ import java.util.Arrays;
 /** A DataGenerator which creates a table of doubles. */
 public class DoubleGenerator extends InputTableGenerator<DoubleGenerator> {
 
+    /**
+     * Arity of the generated values (default 0). When positive, every generated value is an
+     * integer in [0, arity - 1] stored as a double; when zero, every value is a continuous double
+     * in [0, 1).
+     */
+    public static final Param<Integer> ARITY =
+            new IntParam(
+                    "arity",
+                    "Arity of the generated double values. "
+                            + "If set to positive value, each feature would be "
+                            + "an integer in range [0, arity - 1]. "
+                            + "If set to zero, each feature would be "
+                            + "a continuous double in range [0, 1).",
+                    0,
+                    ParamValidators.gtEq(0));
+
+    /** Returns the configured {@link #ARITY}. */
+    public int getArity() {
+        return get(ARITY);
+    }
+
+    /**
+     * Sets {@link #ARITY}; the value is validated by {@code ParamValidators.gtEq(0)}.
+     *
+     * @return the result of {@code set}, typed as this generator class so calls can be chained
+     */
+    public DoubleGenerator setArity(int value) {
+        return set(ARITY, value);
+    }
+
     @Override
     protected RowGenerator[] getRowGenerators() {
         String[][] colNames = getColNames();
         Preconditions.checkState(colNames.length == 1);
         int numOutputCols = colNames[0].length;
+        int arity = getArity();
 
         return new RowGenerator[] {
             new RowGenerator(getNumValues(), getSeed()) {
@@ -41,7 +62,11 @@ public class DoubleGenerator extends InputTableGenerator<DoubleGenerator> {
                 public Row nextRow() {
                     Row r = new Row(numOutputCols);
                     for (int i = 0; i < numOutputCols; i++) {
-                        r.setField(i, random.nextDouble());
+                        if (arity > 0) {
+                            r.setField(i, (double) random.nextInt(arity));
+                        } else {
+                            r.setField(i, random.nextDouble());
+                        }
                     }
                     return r;
                 }
diff --git a/flink-ml-benchmark/src/main/resources/onehotencoder-benchmark.json b/flink-ml-benchmark/src/main/resources/onehotencoder-benchmark.json
new file mode 100644
index 0000000..6195cc2
--- /dev/null
+++ b/flink-ml-benchmark/src/main/resources/onehotencoder-benchmark.json
@@ -0,0 +1,35 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+{
+  "version": 1,
+  "OneHotEncoder": {
+    "stage": {
+      "className": "org.apache.flink.ml.feature.onehotencoder.OneHotEncoder",
+      "paramMap": {
+        "inputCols":  ["input"],
+        "outputCols":  ["output"]
+      }
+    },
+    "inputData": {
+      "className": "org.apache.flink.ml.benchmark.datagenerator.common.DoubleGenerator",
+      "paramMap": {
+        "colNames": [["input"]],
+        "arity": 10,
+        "numValues": 100000
+      }
+    }
+  }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoder.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoder.java
index 9bc1bb7..0543799 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoder.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/onehotencoder/OneHotEncoder.java
@@ -18,24 +18,35 @@
 
 package org.apache.flink.ml.feature.onehotencoder;
 
-import org.apache.flink.api.common.functions.FlatMapFunction;
-import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
 import org.apache.flink.ml.api.Estimator;
-import org.apache.flink.ml.common.datastream.DataStreamUtils;
 import org.apache.flink.ml.common.param.HasHandleInvalid;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.util.ParamUtils;
 import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
 import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
 import org.apache.flink.table.api.internal.TableImpl;
 import org.apache.flink.types.Row;
-import org.apache.flink.util.Collector;
 import org.apache.flink.util.Preconditions;
 
 import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
 
@@ -68,13 +79,20 @@ public class OneHotEncoder
 
         StreamTableEnvironment tEnv =
                 (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
-        DataStream<Tuple2<Integer, Integer>> columnsAndValues =
-                tEnv.toDataStream(inputs[0]).flatMap(new 
ExtractInputColsValueFunction(inputCols));
+        DataStream<Integer[]> localMaxIndices =
+                tEnv.toDataStream(inputs[0])
+                        .transform(
+                                "ExtractInputValueAndFindMaxIndexOperator",
+                                
ObjectArrayTypeInfo.getInfoFor(BasicTypeInfo.INT_TYPE_INFO),
+                                new 
ExtractInputValueAndFindMaxIndexOperator(inputCols));
 
         DataStream<Tuple2<Integer, Integer>> modelData =
-                DataStreamUtils.mapPartition(
-                        columnsAndValues.keyBy(columnIdAndValue -> 
columnIdAndValue.f0),
-                        new FindMaxIndexFunction());
+                localMaxIndices
+                        .transform(
+                                "GenerateModelDataOperator",
+                                
TupleTypeInfo.getBasicTupleTypeInfo(Integer.class, Integer.class),
+                                new GenerateModelDataOperator())
+                        .setParallelism(1);
 
         OneHotEncoderModel model =
                 new 
OneHotEncoderModel().setModelData(tEnv.fromDataStream(modelData));
@@ -97,50 +115,129 @@ public class OneHotEncoder
     }
 
     /**
-     * Extract values of input columns of input data.
-     *
-     * <p>Input: rows of input data containing designated input columns
-     *
-     * <p>Output: Pairs of column index and value stored in those columns
+     * Operator to extract the integer values from input columns and to find the max index value for
+     * each column.
+     *
+     * <p>Input: rows containing the configured input columns, whose values must be non-negative
+     * integral numbers.
+     *
+     * <p>Output: a single {@code Integer[]} per subtask, emitted when input ends, holding for each
+     * input column the largest value this subtask has seen ({@link Integer#MIN_VALUE} for every
+     * column if the subtask received no rows).
      */
-    private static class ExtractInputColsValueFunction
-            implements FlatMapFunction<Row, Tuple2<Integer, Integer>> {
+    private static class ExtractInputValueAndFindMaxIndexOperator
+            extends AbstractStreamOperator<Integer[]>
+            implements OneInputStreamOperator<Row, Integer[]>, BoundedOneInput {
+
         private final String[] inputCols;

-        private ExtractInputColsValueFunction(String[] inputCols) {
+        /** Checkpointed copy of {@link #maxIndices}. */
+        private ListState<Integer[]> maxIndicesState;
+
+        /** Running per-column maximum, indexed like {@link #inputCols}. */
+        private Integer[] maxIndices;
+
+        private ExtractInputValueAndFindMaxIndexOperator(String[] inputCols) {
             this.inputCols = inputCols;
         }

         @Override
-        public void flatMap(Row row, Collector<Tuple2<Integer, Integer>> collector) {
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<Integer[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(BasicTypeInfo.INT_TYPE_INFO);
+
+            maxIndicesState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("maxIndices", type));
+
+            // Operator list state is redistributed across subtasks when parallelism changes, so
+            // a restored subtask may receive zero, one or several snapshotted arrays. Fold all of
+            // them into a single element-wise maximum rather than insisting on a unique element.
+            maxIndices = initMaxIndices();
+            for (Integer[] restored : maxIndicesState.get()) {
+                mergeMaxIndices(maxIndices, restored);
+            }
+        }
+
+        /** Creates a per-column accumulator in which every entry is {@link Integer#MIN_VALUE}. */
+        private Integer[] initMaxIndices() {
+            Integer[] indices = new Integer[inputCols.length];
+            Arrays.fill(indices, Integer.MIN_VALUE);
+            return indices;
+        }
+
+        /** Raises each entry of {@code target} to at least the matching entry of {@code other}. */
+        private static void mergeMaxIndices(Integer[] target, Integer[] other) {
+            Preconditions.checkState(
+                    target.length == other.length,
+                    "Restored state has %s columns but %s input columns are configured.",
+                    other.length,
+                    target.length);
+            for (int i = 0; i < target.length; i++) {
+                if (other[i] > target[i]) {
+                    target[i] = other[i];
+                }
+            }
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            // maxIndices is always non-null here: initializeState() assigns it before any use.
+            maxIndicesState.update(Collections.singletonList(maxIndices));
+        }
+
+        @Override
+        public void processElement(StreamRecord<Row> streamRecord) {
+            Row row = streamRecord.getValue();
             for (int i = 0; i < inputCols.length; i++) {
                 Number number = (Number) row.getField(inputCols[i]);
-                Preconditions.checkArgument(
-                        number.intValue() == number.doubleValue(),
-                        String.format("Value %s cannot be parsed as indexed integer.", number));
-                Preconditions.checkArgument(
-                        number.intValue() >= 0, "Negative value not supported.");
-                collector.collect(new Tuple2<>(i, number.intValue()));
+                int value = number.intValue();
+
+                // An explicit check (rather than Preconditions.checkArgument with a pre-built
+                // message) only runs String.format when the value is actually rejected.
+                if (value != number.doubleValue()) {
+                    throw new IllegalArgumentException(
+                            String.format("Value %s cannot be parsed as indexed integer.", number));
+                }
+                Preconditions.checkArgument(value >= 0, "Negative value not supported.");
+
+                if (value > maxIndices[i]) {
+                    maxIndices[i] = value;
+                }
             }
         }
+
+        @Override
+        public void endInput() {
+            // Emit this subtask's partial maxima exactly once; they are reduced downstream.
+            output.collect(new StreamRecord<>(maxIndices));
+        }
     }
 
-    /** Function to find the max index value for each column. */
-    private static class FindMaxIndexFunction
-            implements MapPartitionFunction<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> {
+    /**
+     * Collects and reduces the max index value in each column and produces the model data.
+     *
+     * <p>Input: per-subtask {@code Integer[]} partial maxima, one entry per input column.
+     *
+     * <p>Output: Pairs of column index and max index value in this column, emitted when input
+     * ends. Nothing is emitted if no partial maxima were received or restored.
+     */
+    private static class GenerateModelDataOperator
+            extends AbstractStreamOperator<Tuple2<Integer, Integer>>
+            implements OneInputStreamOperator<Integer[], Tuple2<Integer, Integer>>,
+                    BoundedOneInput {
+
+        /** Checkpointed copy of {@link #maxIndices}; empty while no input has arrived. */
+        private ListState<Integer[]> maxIndicesState;
+
+        /** Element-wise maximum of all inputs received so far, or null before the first one. */
+        private Integer[] maxIndices;

         @Override
-        public void mapPartition(
-                Iterable<Tuple2<Integer, Integer>> iterable,
-                Collector<Tuple2<Integer, Integer>> collector) {
-            Map<Integer, Integer> map = new HashMap<>();
-            for (Tuple2<Integer, Integer> value : iterable) {
-                map.put(
-                        value.f0,
-                        Math.max(map.getOrDefault(value.f0, Integer.MIN_VALUE), value.f1));
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<Integer[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(BasicTypeInfo.INT_TYPE_INFO);
+
+            maxIndicesState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("maxIndices", type));
+
+            // fit() runs this operator with parallelism 1, so at most one array is restored.
+            maxIndices =
+                    OperatorStateUtils.getUniqueElement(maxIndicesState, "maxIndices").orElse(null);
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            // Upstream subtasks only emit from endInput(), so checkpoints taken while the job is
+            // still running see maxIndices == null. Adding a null element to operator ListState
+            // fails with a NullPointerException, so clear the state instead in that case.
+            if (maxIndices == null) {
+                maxIndicesState.clear();
+            } else {
+                maxIndicesState.update(Collections.singletonList(maxIndices));
+            }
+        }
+
+        @Override
+        public void processElement(StreamRecord<Integer[]> streamRecord) {
+            Integer[] indices = streamRecord.getValue();
+            if (maxIndices == null) {
+                // Copy defensively: the incoming array may still be referenced by the sender
+                // (e.g. when chained with object reuse enabled), and it is mutated below later.
+                maxIndices = Arrays.copyOf(indices, indices.length);
+                return;
+            }
+            for (int i = 0; i < maxIndices.length; i++) {
+                if (indices[i] > maxIndices[i]) {
+                    maxIndices[i] = indices[i];
+                }
             }
-            for (Map.Entry<Integer, Integer> entry : map.entrySet()) {
-                collector.collect(new Tuple2<>(entry.getKey(), entry.getValue()));
+        }
+
+        @Override
+        public void endInput() {
+            if (maxIndices == null) {
+                // No partial maxima were received or restored, so there is no model data.
+                return;
+            }
+            for (int i = 0; i < maxIndices.length; i++) {
+                output.collect(new StreamRecord<>(Tuple2.of(i, maxIndices[i])));
             }
         }
     }

Reply via email to