This is an automated email from the ASF dual-hosted git repository.
lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new 0769392 [FLINK-25616] Add Transformer for VectorAssembler
0769392 is described below
commit 076939285ffb7e5167371ae4433a1b4f394e6753
Author: weibo <[email protected]>
AuthorDate: Sat Apr 2 15:32:36 2022 +0800
[FLINK-25616] Add Transformer for VectorAssembler
This closes #56.
---
.../apache/flink/ml/common/param/HasOutputCol.java | 39 +++++
.../feature/vectorassembler/VectorAssembler.java | 182 +++++++++++++++++++++
.../vectorassembler/VectorAssemblerParams.java | 63 +++++++
.../flink/ml/feature/VectorAssemblerTest.java | 178 ++++++++++++++++++++
4 files changed, 462 insertions(+)
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasOutputCol.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasOutputCol.java
new file mode 100644
index 0000000..e191058
--- /dev/null
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasOutputCol.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+import org.apache.flink.ml.param.WithParams;
+
+/**
+ * Shared param interface for stages that expose a single {@code outputCol} param.
+ *
+ * @param <T> The type returned by the fluent setter.
+ */
+public interface HasOutputCol<T> extends WithParams<T> {
+
+    /** Name of the output column; the default value is "output". */
+    Param<String> OUTPUT_COL =
+            new StringParam(
+                    "outputCol",
+                    "Output column name.",
+                    "output",
+                    ParamValidators.notNull());
+
+    /** Returns the value currently stored for {@link #OUTPUT_COL}. */
+    default String getOutputCol() {
+        final String outputCol = get(OUTPUT_COL);
+        return outputCol;
+    }
+
+    /**
+     * Stores {@code value} for {@link #OUTPUT_COL}.
+     *
+     * @param value the output column name
+     * @return whatever {@code set} returns, so calls can be chained
+     */
+    default T setOutputCol(String value) {
+        return set(OUTPUT_COL, value);
+    }
+}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java
new file mode 100644
index 0000000..61d84d6
--- /dev/null
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.vectorassembler;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Transformer;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.Map;
+
+/**
+ * A feature transformer that combines a given list of input columns into a vector column. Types of
+ * input columns must be either vector or numerical value.
+ *
+ * <p>When a row cannot be assembled, the {@code handleInvalid} param decides whether the error is
+ * rethrown, the row is dropped, or the row is emitted with a null output value.
+ */
+public class VectorAssembler
+        implements Transformer<VectorAssembler>, VectorAssemblerParams<VectorAssembler> {
+    /**
+     * The assembled vector is dense when the number of collected entries times this ratio exceeds
+     * the vector size, and sparse otherwise.
+     */
+    private static final double RATIO = 1.5;
+
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public VectorAssembler() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    /**
+     * Appends the assembled vector column to the single input table.
+     *
+     * @param inputs exactly one table; it must contain the configured input columns
+     * @return a one-element array whose table has the input columns followed by the output column
+     */
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(), TypeInformation.of(Vector.class)),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol()));
+        DataStream<Row> output =
+                tEnv.toDataStream(inputs[0])
+                        .flatMap(
+                                new AssemblerFunc(getInputCols(), getHandleInvalid()),
+                                outputTypeInfo);
+        Table outputTable = tEnv.fromDataStream(output);
+        return new Table[] {outputTable};
+    }
+
+    /** Assembles the configured input columns of each row and appends the result to the row. */
+    private static class AssemblerFunc implements FlatMapFunction<Row, Row> {
+        private final String[] inputCols;
+        private final String handleInvalid;
+
+        public AssemblerFunc(String[] inputCols, String handleInvalid) {
+            this.inputCols = inputCols;
+            this.handleInvalid = handleInvalid;
+        }
+
+        @Override
+        public void flatMap(Row value, Collector<Row> out) throws Exception {
+            Vector assembledVector;
+            // Only reading and assembling the input values is guarded: a failure raised by the
+            // collector is not an invalid entry and must propagate untouched.
+            try {
+                Object[] objects = new Object[inputCols.length];
+                for (int i = 0; i < objects.length; ++i) {
+                    objects[i] = value.getField(inputCols[i]);
+                }
+                assembledVector = assemble(objects);
+            } catch (Exception e) {
+                switch (handleInvalid) {
+                    case VectorAssemblerParams.ERROR_INVALID:
+                        throw e;
+                    case VectorAssemblerParams.SKIP_INVALID:
+                        return;
+                    case VectorAssemblerParams.KEEP_INVALID:
+                        out.collect(Row.join(value, Row.of((Object) null)));
+                        return;
+                    default:
+                        // Keep the original failure attached so it is not lost.
+                        throw new UnsupportedOperationException(
+                                "handleInvalid=" + handleInvalid + " is not supported", e);
+                }
+            }
+            out.collect(Row.join(value, Row.of(assembledVector)));
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    /**
+     * Loads a VectorAssembler from the metadata stored under {@code path}.
+     *
+     * @param env not referenced by this method
+     * @param path the location passed to {@code ReadWriteUtils.loadStageParam}
+     */
+    public static VectorAssembler load(StreamTableEnvironment env, String path) throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    /**
+     * Concatenates the given values into one vector.
+     *
+     * @param objects the values to concatenate; each must be a non-null Number or Vector
+     * @return a DenseVector or a SparseVector, chosen by {@link #RATIO}
+     */
+    private static Vector assemble(Object[] objects) {
+        int offset = 0;
+        // Maps each position of the assembled vector to the value collected for it.
+        Map<Integer, Double> map = new LinkedHashMap<>(objects.length);
+        for (Object object : objects) {
+            Preconditions.checkNotNull(object, "Input column value should not be null.");
+            if (object instanceof Number) {
+                map.put(offset++, ((Number) object).doubleValue());
+            } else if (object instanceof Vector) {
+                offset = appendVector((Vector) object, map, offset);
+            } else {
+                throw new IllegalArgumentException("Input type has not been supported yet.");
+            }
+        }
+
+        if (map.size() * RATIO > offset) {
+            DenseVector assembledVector = new DenseVector(offset);
+            for (Map.Entry<Integer, Double> entry : map.entrySet()) {
+                assembledVector.values[entry.getKey()] = entry.getValue();
+            }
+            return assembledVector;
+        } else {
+            return convertMapToSparseVector(offset, map);
+        }
+    }
+
+    /**
+     * Copies the entries of {@code vec} into {@code map}, shifted by {@code offset}.
+     *
+     * @return the offset advanced by the size of {@code vec}
+     */
+    private static int appendVector(Vector vec, Map<Integer, Double> map, int offset) {
+        if (vec instanceof SparseVector) {
+            SparseVector sparseVector = (SparseVector) vec;
+            int[] indices = sparseVector.indices;
+            double[] values = sparseVector.values;
+            for (int i = 0; i < indices.length; ++i) {
+                map.put(offset + indices[i], values[i]);
+            }
+            offset += sparseVector.size();
+        } else {
+            DenseVector denseVector = (DenseVector) vec;
+            for (int i = 0; i < denseVector.size(); ++i) {
+                map.put(offset++, denseVector.values[i]);
+            }
+        }
+        return offset;
+    }
+
+    /** Builds a SparseVector of the given size from the map entries, in map iteration order. */
+    private static SparseVector convertMapToSparseVector(int size, Map<Integer, Double> map) {
+        int[] indices = new int[map.size()];
+        double[] values = new double[map.size()];
+        int offset = 0;
+        for (Map.Entry<Integer, Double> entry : map.entrySet()) {
+            indices[offset] = entry.getKey();
+            values[offset++] = entry.getValue();
+        }
+        return new SparseVector(size, indices, values);
+    }
+}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssemblerParams.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssemblerParams.java
new file mode 100644
index 0000000..5e2cda4
--- /dev/null
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssemblerParams.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.vectorassembler;
+
+import org.apache.flink.ml.common.param.HasInputCols;
+import org.apache.flink.ml.common.param.HasOutputCol;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+
+/**
+ * Params of VectorAssembler.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface VectorAssemblerParams<T> extends HasInputCols<T>, HasOutputCol<T> {
+
+    String ERROR_INVALID = "error";
+    String SKIP_INVALID = "skip";
+    String KEEP_INVALID = "keep";
+
+    /**
+     * Supported options and the corresponding behaviors to handle invalid entries are listed as
+     * follows.
+     *
+     * <ul>
+     *   <li>error: raise an exception.
+     *   <li>skip: filter out rows with bad values.
+     *   <li>keep: output bad rows with output column's value set to null.
+     * </ul>
+     */
+    Param<String> HANDLE_INVALID =
+            new StringParam(
+                    "handleInvalid",
+                    "Strategy to handle invalid entries.",
+                    ERROR_INVALID,
+                    ParamValidators.inArray(ERROR_INVALID, SKIP_INVALID, KEEP_INVALID));
+
+    /** Returns the value currently stored for {@link #HANDLE_INVALID}. */
+    default String getHandleInvalid() {
+        return get(HANDLE_INVALID);
+    }
+
+    /**
+     * Stores {@code value} for {@link #HANDLE_INVALID}.
+     *
+     * @param value one of {@link #ERROR_INVALID}, {@link #SKIP_INVALID} or {@link #KEEP_INVALID}
+     * @return whatever {@code set} returns, so calls can be chained
+     */
+    default T setHandleInvalid(String value) {
+        // Return the result of set(...) directly, as HasOutputCol#setOutputCol does, instead of
+        // an unchecked cast of this.
+        return set(HANDLE_INVALID, value);
+    }
+}
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorAssemblerTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorAssemblerTest.java
new file mode 100644
index 0000000..193077c
--- /dev/null
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorAssemblerTest.java
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.common.param.HasHandleInvalid;
+import org.apache.flink.ml.feature.vectorassembler.VectorAssembler;
+import org.apache.flink.ml.feature.vectorassembler.VectorAssemblerParams;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+
+/** Tests VectorAssembler. */
+public class VectorAssemblerTest extends AbstractTestBase {
+
+    private StreamTableEnvironment tEnv;
+    private Table inputDataTable;
+
+    // Rows 0 and 1 are fully populated; row 2 has null inputs to exercise handleInvalid.
+    private static final List<Row> INPUT_DATA =
+            Arrays.asList(
+                    Row.of(
+                            0,
+                            Vectors.dense(2.1, 3.1),
+                            1.0,
+                            Vectors.sparse(5, new int[] {3}, new double[] {1.0})),
+                    Row.of(
+                            1,
+                            Vectors.dense(2.1, 3.1),
+                            1.0,
+                            Vectors.sparse(
+                                    5, new int[] {4, 2, 3, 1}, new double[] {4.0, 2.0, 3.0, 1.0})),
+                    Row.of(2, null, null, null));
+
+    private static final SparseVector EXPECTED_OUTPUT_DATA_1 =
+            Vectors.sparse(8, new int[] {0, 1, 2, 6}, new double[] {2.1, 3.1, 1.0, 1.0});
+    private static final DenseVector EXPECTED_OUTPUT_DATA_2 =
+            Vectors.dense(2.1, 3.1, 1.0, 0.0, 1.0, 2.0, 3.0, 4.0);
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        DataStream<Row> dataStream = env.fromCollection(INPUT_DATA);
+        inputDataTable = tEnv.fromDataStream(dataStream).as("id", "vec", "num", "sparseVec");
+    }
+
+    /**
+     * Collects {@code output} and checks the row count and the assembled vector of every row.
+     *
+     * @param output the table produced by the transformer
+     * @param outputCol the name of the assembled vector column
+     * @param outputSize the expected number of rows
+     */
+    private void verifyOutputResult(Table output, String outputCol, int outputSize)
+            throws Exception {
+        DataStream<Row> dataStream = tEnv.toDataStream(output);
+        List<Row> results = IteratorUtils.toList(dataStream.executeAndCollect());
+        assertEquals(outputSize, results.size());
+        for (Row result : results) {
+            // Compare ids by value; reference equality on boxed integers only works by accident.
+            Object id = result.getField(0);
+            if (Integer.valueOf(0).equals(id)) {
+                assertEquals(EXPECTED_OUTPUT_DATA_1, result.getField(outputCol));
+            } else if (Integer.valueOf(1).equals(id)) {
+                assertEquals(EXPECTED_OUTPUT_DATA_2, result.getField(outputCol));
+            } else {
+                assertNull(result.getField(outputCol));
+            }
+        }
+    }
+
+    @Test
+    public void testParam() {
+        VectorAssembler vectorAssembler = new VectorAssembler();
+        assertEquals(HasHandleInvalid.ERROR_INVALID, vectorAssembler.getHandleInvalid());
+        assertEquals("output", vectorAssembler.getOutputCol());
+        vectorAssembler
+                .setInputCols("vec", "num", "sparseVec")
+                .setOutputCol("assembledVec")
+                .setHandleInvalid(HasHandleInvalid.SKIP_INVALID);
+        assertArrayEquals(new String[] {"vec", "num", "sparseVec"}, vectorAssembler.getInputCols());
+        assertEquals(HasHandleInvalid.SKIP_INVALID, vectorAssembler.getHandleInvalid());
+        assertEquals("assembledVec", vectorAssembler.getOutputCol());
+    }
+
+    @Test
+    public void testKeepInvalid() throws Exception {
+        VectorAssembler vectorAssembler =
+                new VectorAssembler()
+                        .setInputCols("vec", "num", "sparseVec")
+                        .setOutputCol("assembledVec")
+                        .setHandleInvalid(VectorAssemblerParams.KEEP_INVALID);
+        Table output = vectorAssembler.transform(inputDataTable)[0];
+        assertEquals(
+                Arrays.asList("id", "vec", "num", "sparseVec", "assembledVec"),
+                output.getResolvedSchema().getColumnNames());
+        verifyOutputResult(output, vectorAssembler.getOutputCol(), 3);
+    }
+
+    @Test
+    public void testErrorInvalid() {
+        VectorAssembler vectorAssembler =
+                new VectorAssembler()
+                        .setInputCols("vec", "num", "sparseVec")
+                        .setOutputCol("assembledVec")
+                        .setHandleInvalid(HasHandleInvalid.ERROR_INVALID);
+        try {
+            Table outputTable = vectorAssembler.transform(inputDataTable)[0];
+            outputTable.execute().collect().next();
+            Assert.fail("Expected the job to fail on the row with null input values.");
+        } catch (Exception e) {
+            // Walk to the root cause instead of assuming a fixed number of wrapper exceptions.
+            Throwable rootCause = e;
+            while (rootCause.getCause() != null) {
+                rootCause = rootCause.getCause();
+            }
+            assertEquals("Input column value should not be null.", rootCause.getMessage());
+        }
+    }
+
+    @Test
+    public void testSkipInvalid() throws Exception {
+        VectorAssembler vectorAssembler =
+                new VectorAssembler()
+                        .setInputCols("vec", "num", "sparseVec")
+                        .setOutputCol("assembledVec")
+                        .setHandleInvalid(VectorAssemblerParams.SKIP_INVALID);
+        Table output = vectorAssembler.transform(inputDataTable)[0];
+        assertEquals(
+                Arrays.asList("id", "vec", "num", "sparseVec", "assembledVec"),
+                output.getResolvedSchema().getColumnNames());
+        verifyOutputResult(output, vectorAssembler.getOutputCol(), 2);
+    }
+
+    @Test
+    public void testSaveLoadAndTransform() throws Exception {
+        VectorAssembler vectorAssembler =
+                new VectorAssembler()
+                        .setInputCols("vec", "num", "sparseVec")
+                        .setOutputCol("assembledVec")
+                        .setHandleInvalid(HasHandleInvalid.SKIP_INVALID);
+        VectorAssembler loadedVectorAssembler =
+                StageTestUtils.saveAndReload(
+                        tEnv, vectorAssembler, TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+        Table output = loadedVectorAssembler.transform(inputDataTable)[0];
+        verifyOutputResult(output, loadedVectorAssembler.getOutputCol(), 2);
+    }
+}