yunfengzhou-hub commented on a change in pull request #56: URL: https://github.com/apache/flink-ml/pull/56#discussion_r839507707
########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java ########## @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.vectorassembler; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Transformer; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.common.param.HasHandleInvalid; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import 
java.io.IOException; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * Vector assembler is a transformer that combines a given list of columns into a single vector + * column. It will combine raw features and features generated by different feature transformers + * into a single feature vector. The input features of this transformer must be a vector feature or + * a numerical feature. + */ +public class VectorAssembler + implements Transformer<VectorAssembler>, VectorAssemblerParams<VectorAssembler> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private static final double RATIO = 1.5; + + public VectorAssembler() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll( + inputTypeInfo.getFieldTypes(), TypeInformation.of(Vector.class)), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol())); + DataStream<Row> output = + tEnv.toDataStream(inputs[0]) + .map(new AssemblerFunc(getInputCols(), getHandleInvalid()), outputTypeInfo); + Table outputTable = tEnv.fromDataStream(output); + return new Table[] {outputTable}; + } + + private static class AssemblerFunc implements MapFunction<Row, Row> { + private final String[] inputCols; + private final String handleInvalid; + + public AssemblerFunc(String[] inputCols, String handleInvalid) { + this.inputCols = inputCols; + this.handleInvalid = handleInvalid; + } + + @Override + public Row map(Row value) { + Object[] objects = new Object[inputCols.length]; + for (int i = 0; i < objects.length; ++i) { + objects[i] = value.getField(inputCols[i]); + } + return 
Row.join(value, Row.of(assemble(objects, handleInvalid))); + } + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static VectorAssembler load(StreamTableEnvironment env, String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + private static Vector assemble(Object[] objects, String handleInvalid) { + int offset = 0; + Map<Integer, Double> map = new LinkedHashMap<>(objects.length); + for (Object object : objects) { + try { + if (object instanceof Number) { + map.put(offset++, ((Number) object).doubleValue()); + } else if (object instanceof Vector) { + offset = appendVector((Vector) object, map, offset); + } else { + throw new UnsupportedOperationException( + "Vector assembler : input type not support yet."); Review comment: nit: `"has not been supported yet"` ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java ########## @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.feature.vectorassembler; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Transformer; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.common.param.HasHandleInvalid; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * Vector assembler is a transformer that combines a given list of columns into a single vector + * column. It will combine raw features and features generated by different feature transformers + * into a single feature vector. The input features of this transformer must be a vector feature or + * a numerical feature. + */ +public class VectorAssembler + implements Transformer<VectorAssembler>, VectorAssemblerParams<VectorAssembler> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private static final double RATIO = 1.5; + + public VectorAssembler() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... 
inputs) { + Preconditions.checkArgument(inputs.length == 1); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll( + inputTypeInfo.getFieldTypes(), TypeInformation.of(Vector.class)), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol())); + DataStream<Row> output = + tEnv.toDataStream(inputs[0]) + .map(new AssemblerFunc(getInputCols(), getHandleInvalid()), outputTypeInfo); + Table outputTable = tEnv.fromDataStream(output); + return new Table[] {outputTable}; + } + + private static class AssemblerFunc implements MapFunction<Row, Row> { + private final String[] inputCols; + private final String handleInvalid; + + public AssemblerFunc(String[] inputCols, String handleInvalid) { + this.inputCols = inputCols; + this.handleInvalid = handleInvalid; + } + + @Override + public Row map(Row value) { + Object[] objects = new Object[inputCols.length]; + for (int i = 0; i < objects.length; ++i) { + objects[i] = value.getField(inputCols[i]); + } + return Row.join(value, Row.of(assemble(objects, handleInvalid))); + } + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static VectorAssembler load(StreamTableEnvironment env, String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + private static Vector assemble(Object[] objects, String handleInvalid) { + int offset = 0; + Map<Integer, Double> map = new LinkedHashMap<>(objects.length); + for (Object object : objects) { + try { + if (object instanceof Number) { + map.put(offset++, ((Number) object).doubleValue()); + } else if (object instanceof Vector) { + offset = appendVector((Vector) object, map, offset); + } else { + 
throw new UnsupportedOperationException( + "Vector assembler : input type not support yet."); + } + } catch (Exception e) { + switch (handleInvalid) { + case HasHandleInvalid.ERROR_INVALID: + throw new RuntimeException("Vector assembler failed.", e); + case HasHandleInvalid.SKIP_INVALID: + return null; + default: + } + } + } + + if (map.size() * RATIO > offset) { + DenseVector assembledVector = new DenseVector(offset); + for (int key : map.keySet()) { + assembledVector.values[key] = map.get(key); + } + return assembledVector; + } else { + return convertMapToSparseVector(offset, map); + } + } + + private static int appendVector(Vector vec, Map<Integer, Double> map, int offset) { + if (vec == null) { + throw new RuntimeException("VectorAssembler Error: vector data is null."); Review comment: nit: `Preconditions.checkNotNull(vec, "VectorAssembler Error:...")` ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java ########## @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.feature.vectorassembler; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Transformer; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.common.param.HasHandleInvalid; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; + +/** + * Vector assembler is a transformer that combines a given list of columns into a single vector + * column. It will combine raw features and features generated by different feature transformers + * into a single feature vector. The input features of this transformer must be a vector feature or + * a numerical feature. + */ +public class VectorAssembler + implements Transformer<VectorAssembler>, VectorAssemblerParams<VectorAssembler> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private static final double RATIO = 1.5; + + public VectorAssembler() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... 
inputs) { + Preconditions.checkArgument(inputs.length == 1); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll( + inputTypeInfo.getFieldTypes(), TypeInformation.of(Vector.class)), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol())); + DataStream<Row> output = + tEnv.toDataStream(inputs[0]) + .map(new AssemblerFunc(getInputCols(), getHandleInvalid()), outputTypeInfo); + Table outputTable = tEnv.fromDataStream(output); + return new Table[] {outputTable}; + } + + private static class AssemblerFunc implements MapFunction<Row, Row> { + private final String[] inputCols; + private final String handleInvalid; + + public AssemblerFunc(String[] inputCols, String handleInvalid) { + this.inputCols = inputCols; + this.handleInvalid = handleInvalid; + } + + @Override + public Row map(Row value) { + Object[] objects = new Object[inputCols.length]; + for (int i = 0; i < objects.length; ++i) { + objects[i] = value.getField(inputCols[i]); + } + return Row.join(value, Row.of(assemble(objects, handleInvalid))); + } + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static VectorAssembler load(StreamTableEnvironment env, String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + private static Vector assemble(Object[] objects, String handleInvalid) { + int offset = 0; + Map<Integer, Double> map = new LinkedHashMap<>(objects.length); + for (Object object : objects) { + try { + if (object instanceof Number) { + map.put(offset++, ((Number) object).doubleValue()); + } else if (object instanceof Vector) { + offset = appendVector((Vector) object, map, offset); + } else { + 
throw new UnsupportedOperationException( + "Vector assembler : input type not support yet."); + } + } catch (Exception e) { + switch (handleInvalid) { + case HasHandleInvalid.ERROR_INVALID: + throw new RuntimeException("Vector assembler failed.", e); + case HasHandleInvalid.SKIP_INVALID: + return null; + default: + } + } + } + + if (map.size() * RATIO > offset) { + DenseVector assembledVector = new DenseVector(offset); + for (int key : map.keySet()) { + assembledVector.values[key] = map.get(key); + } + return assembledVector; + } else { + return convertMapToSparseVector(offset, map); + } + } + + private static int appendVector(Vector vec, Map<Integer, Double> map, int offset) { + if (vec == null) { + throw new RuntimeException("VectorAssembler Error: vector data is null."); + } + if (vec instanceof SparseVector) { + SparseVector sparseVector = (SparseVector) vec; + if (sparseVector.size() <= 0) { + throw new RuntimeException("The append sparse vector must have a positive size."); Review comment: nit: `"The appended sparse ..."` ########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorAssemblerTest.java ########## @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.feature; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.common.param.HasHandleInvalid; +import org.apache.flink.ml.feature.vectorassembler.VectorAssembler; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.StageTestUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.test.util.AbstractTestBase; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.util.Arrays; +import java.util.List; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +/** Tests VectorAssembler. 
*/ +public class VectorAssemblerTest extends AbstractTestBase { + + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + private Table inputDataTable; + + private static final List<Row> INPUT_DATA = + Arrays.asList( + Row.of( + 0, + Vectors.dense(2.1, 3.1), + 1.0, + Vectors.sparse(5, new int[] {3}, new double[] {1.0})), + Row.of( + 1, + Vectors.dense(2.1, 3.1), + 1.0, + Vectors.sparse( + 5, new int[] {4, 2, 3, 1}, new double[] {4.0, 2.0, 3.0, 1.0})), + Row.of(2, null, 1.0, null)); + + private static final SparseVector EXPECTED_DATA_1 = + Vectors.sparse(8, new int[] {0, 1, 2, 6}, new double[] {2.1, 3.1, 1.0, 1.0}); + private static final DenseVector EXPECTED_DATA_2 = + Vectors.dense(2.1, 3.1, 1.0, 0.0, 1.0, 2.0, 3.0, 4.0); + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(4); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + Schema schema = + Schema.newBuilder() + .column("f0", DataTypes.INT()) + .column("f1", DataTypes.of(DenseVector.class)) + .column("f2", DataTypes.DOUBLE()) + .column("f3", DataTypes.of(SparseVector.class)) + .build(); + DataStream<Row> dataStream = env.fromCollection(INPUT_DATA); + inputDataTable = + tEnv.fromDataStream(dataStream, schema).as("id", "vec", "num", "sparseVec"); + } + + private void verifyPredictionResult(Table output, String outputCol) throws Exception { + DataStream<Row> dataStream = tEnv.toDataStream(output); + List<Row> results = IteratorUtils.toList(dataStream.executeAndCollect()); + assertEquals(3, results.size()); + for (Row result : results) { + if (result.getField(0) == (Object) 0) { + assertEquals(EXPECTED_DATA_1, result.getField(outputCol)); + } else if (result.getField(0) == (Object) 1) { + 
assertEquals(EXPECTED_DATA_2, result.getField(outputCol)); + } else { + assertNull(result.getField(outputCol)); + } + } + } + + @Test + public void testParam() { + VectorAssembler vectorAssembler = new VectorAssembler(); + assertEquals(HasHandleInvalid.ERROR_INVALID, vectorAssembler.getHandleInvalid()); + assertEquals("output", vectorAssembler.getOutputCol()); + vectorAssembler + .setInputCols("vec", "num", "sparseVec") + .setOutputCol("assembledVec") + .setHandleInvalid(HasHandleInvalid.SKIP_INVALID); + assertArrayEquals(new String[] {"vec", "num", "sparseVec"}, vectorAssembler.getInputCols()); + assertEquals(HasHandleInvalid.SKIP_INVALID, vectorAssembler.getHandleInvalid()); + assertEquals("assembledVec", vectorAssembler.getOutputCol()); + } + + @Test + public void testTransform() throws Exception { + VectorAssembler vectorAssembler = + new VectorAssembler() + .setInputCols("vec", "num", "sparseVec") + .setOutputCol("assembledVec") + .setHandleInvalid(HasHandleInvalid.SKIP_INVALID); + Table output = vectorAssembler.transform(inputDataTable)[0]; + verifyPredictionResult(output, vectorAssembler.getOutputCol()); + } + + @Test + public void testHandleInvalidOptions() { + VectorAssembler vectorAssembler = + new VectorAssembler() + .setInputCols("vec", "num", "sparseVec") + .setOutputCol("assembledVec") + .setHandleInvalid(HasHandleInvalid.ERROR_INVALID); + try { + Table outputTable = vectorAssembler.transform(inputDataTable)[0]; + outputTable.execute().collect().next(); + Assert.fail("Expected IllegalArgumentException"); + } catch (Exception e) { + assertEquals(RuntimeException.class, ((Throwable) e).getClass()); Review comment: Can we check the detailed exception message? For example in `LogisticRegressionTest` ```java assertEquals( "Multinomial classification is not supported yet. Supported options: [auto, binomial].", e.getCause().getCause().getMessage()); ``` -- This is an automated message from the Apache Git Service. 
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected]. For queries about this service, please contact Infrastructure at: [email protected].
