zhipeng93 commented on code in PR #131: URL: https://github.com/apache/flink-ml/pull/131#discussion_r929449313
########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorslicer/VectorSlicer.java: ########## @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.vectorslicer; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Transformer; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.Arrays; +import java.util.HashMap; 
+import java.util.Map; + +/** + * A transformer that transforms a vector to a new one with a sub-array of the original features. It + * is useful for extracting features from a given vector. If the indices acquired from setIndices() + * are not in order, the indices of the result vector will be sorted. + */ +public class VectorSlicer implements Transformer<VectorSlicer>, VectorSlicerParams<VectorSlicer> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public VectorSlicer() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), VectorTypeInfo.INSTANCE), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol())); + DataStream<Row> output = + tEnv.toDataStream(inputs[0]) + .map(new VectorSlice(getIndices(), getInputCol()), outputTypeInfo); + Table outputTable = tEnv.fromDataStream(output); + return new Table[] {outputTable}; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static VectorSlicer load(StreamTableEnvironment env, String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + /** Vector slice function. */ Review Comment: Could you update the Javadoc here so that it contains more information? `Vector slice function` conveys no useful information.
########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorslicer/VectorSlicer.java: ########## @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.vectorslicer; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Transformer; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.Arrays; +import java.util.HashMap; 
+import java.util.Map; + +/** + * A transformer that transforms a vector to a new one with a sub-array of the original features. It + * is useful for extracting features from a given vector. If the indices acquired from setIndices() + * are not in order, the indices of the result vector will be sorted. + */ +public class VectorSlicer implements Transformer<VectorSlicer>, VectorSlicerParams<VectorSlicer> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public VectorSlicer() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), VectorTypeInfo.INSTANCE), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol())); + DataStream<Row> output = + tEnv.toDataStream(inputs[0]) + .map(new VectorSlice(getIndices(), getInputCol()), outputTypeInfo); + Table outputTable = tEnv.fromDataStream(output); + return new Table[] {outputTable}; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static VectorSlicer load(StreamTableEnvironment env, String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + /** Vector slice function. 
*/ + private static class VectorSlice implements MapFunction<Row, Row> { + private final Integer[] indices; + private final String inputCol; + + public VectorSlice(Integer[] indices, String inputCol) { + this.indices = indices; + Arrays.sort(this.indices); Review Comment: nit: we could sort the indices outside this function, i.e., before constructing the job graph before Line#68. ########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorslicer/VectorSlicer.java: ########## @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.feature.vectorslicer; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Transformer; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +/** + * A transformer that transforms a vector to a new one with a sub-array of the original features. It Review Comment: nit: A Transformer... ########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorslicer/VectorSlicer.java: ########## @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.vectorslicer; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Transformer; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +/** + * A transformer that transforms a vector to a new one with a sub-array of the original features. It + * is useful for extracting features from a given vector. If the indices acquired from setIndices() + * are not in order, the indices of the result vector will be sorted. 
+ */ +public class VectorSlicer implements Transformer<VectorSlicer>, VectorSlicerParams<VectorSlicer> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public VectorSlicer() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), VectorTypeInfo.INSTANCE), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol())); + DataStream<Row> output = + tEnv.toDataStream(inputs[0]) + .map(new VectorSlice(getIndices(), getInputCol()), outputTypeInfo); + Table outputTable = tEnv.fromDataStream(output); + return new Table[] {outputTable}; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static VectorSlicer load(StreamTableEnvironment env, String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + /** Vector slice function. 
*/ + private static class VectorSlice implements MapFunction<Row, Row> { + private final Integer[] indices; + private final String inputCol; + + public VectorSlice(Integer[] indices, String inputCol) { + this.indices = indices; + Arrays.sort(this.indices); + this.inputCol = inputCol; + } + + @Override + public Row map(Row row) throws Exception { + Vector inputVec = row.getFieldAs(inputCol); + Vector outputVec; + if (inputVec instanceof DenseVector) { + double[] values = new double[indices.length]; + for (int i = 0; i < indices.length; ++i) { + if (indices[i] >= inputVec.size()) { Review Comment: Since the indices are sorted, the checks at Line#107 and Line#119 could be moved to Line#102 for readability and efficiency. Also, how about we update the error message to `Index value is greater than vector size: + inputVec.size`? ########## flink-ml-python/pyflink/ml/lib/feature/vectorslicer.py: ########## @@ -0,0 +1,70 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+################################################################################ + +from typing import Tuple +from pyflink.ml.core.wrapper import JavaWithParams +from pyflink.ml.core.param import IntArrayParam +from pyflink.ml.lib.feature.common import JavaFeatureTransformer +from pyflink.ml.lib.param import HasInputCol, HasOutputCol, ParamValidators, Param + + +class _VectorSlicerParams( + JavaWithParams, + HasInputCol, + HasOutputCol +): + """ + Params for :class:`VectorSlicer`. + """ + + INDICES: Param[Tuple[int, ...]] = IntArrayParam( + "indices", + "An array of indices to select features from a vector column.", + None, + ParamValidators.non_empty_array()) + + def __init__(self, java_params): + super(_VectorSlicerParams, self).__init__(java_params) + + def set_indices(self, *ind: int): + return self.set(self.INDICES, ind) + + def get_indices(self) -> Tuple[int, ...]: + return self.get(self.INDICES) + + @property + def indices(self) -> Tuple[int, ...]: + return self.get_indices() + + +class VectorSlicer(JavaFeatureTransformer, _VectorSlicerParams): + """ + A feature transformer that transforms a vector to a new one with a sub-array of the original Review Comment: nit: make the python doc consistent with java doc. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
