yunfengzhou-hub commented on code in PR #158: URL: https://github.com/apache/flink-ml/pull/158#discussion_r980634965
########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelectorParams.java: ########## @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.variancethresholdselector; + +import org.apache.flink.ml.common.param.HasFeaturesCol; +import org.apache.flink.ml.common.param.HasOutputCol; +import org.apache.flink.ml.param.DoubleParam; +import org.apache.flink.ml.param.Param; + +/** + * * Params of VarianceThresholdSelectorModel. + * + * @param <T> The class type of this instance. + */ +public interface VarianceThresholdSelectorParams<T> extends HasFeaturesCol<T>, HasOutputCol<T> { + + Param<Double> VARIANCE_THRESHOLD = + new DoubleParam( + "varianceThreshold", + "Param for variance threshold. Features with " + + "a variance not greater than this threshold will be removed.", + 0.0); Review Comment: nit: we can add a `ParamValidators.gtEq(0)` here. ########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelectorModel.java: ########## @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.variancethresholdselector; + +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** + * A Model which removes low-variance data using the model data computed by {@link + * VarianceThresholdSelector}. 
+ */ +public class VarianceThresholdSelectorModel + implements Model<VarianceThresholdSelectorModel>, + VarianceThresholdSelectorParams<VarianceThresholdSelectorModel> { + + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private Table modelDataTable; + + public VarianceThresholdSelectorModel() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public VarianceThresholdSelectorModel setModelData(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + modelDataTable = inputs[0]; + return this; + } + + @Override + public Table[] getModelData() { + return new Table[] {modelDataTable}; + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + ReadWriteUtils.saveModelData( + VarianceThresholdSelectorModelData.getModelDataStream(modelDataTable), + path, + new VarianceThresholdSelectorModelData.ModelDataEncoder()); + } + + public static VarianceThresholdSelectorModel load(StreamTableEnvironment tEnv, String path) + throws IOException { + VarianceThresholdSelectorModel model = ReadWriteUtils.loadStageParam(path); + Table modelDataTable = + ReadWriteUtils.loadModelData( + tEnv, path, new VarianceThresholdSelectorModelData.ModelDataDecoder()); + return model.setModelData(modelDataTable); + } + + @Override + public Table[] transform(Table... 
inputs) { + Preconditions.checkArgument(inputs.length == 1); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<Row> data = tEnv.toDataStream(inputs[0]); + DataStream<VarianceThresholdSelectorModelData> varianceThresholdSelectorModel = + VarianceThresholdSelectorModelData.getModelDataStream(modelDataTable); + + final String broadcastModelKey = "broadcastModelKey"; + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll( + inputTypeInfo.getFieldTypes(), + TypeInformation.of(DenseVector.class)), Review Comment: nit: `DenseVectorTypeInfo` ########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelectorModelData.java: ########## @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.feature.variancethresholdselector; + +import org.apache.flink.api.common.serialization.Encoder; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.connector.file.src.reader.SimpleStreamFormat; +import org.apache.flink.core.fs.FSDataInputStream; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; + +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; + +import java.io.IOException; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.List; + +/** + * Model data of {@link VarianceThresholdSelectorModel}. + * + * <p>This class also provides methods to convert model data from Table to a data stream, and + * classes to save/load model data. + */ +public class VarianceThresholdSelectorModelData { + + public int numOfFeatures; + public int[] indices; Review Comment: Do you think it would be better to filter the variances against the threshold in the Model, instead of in the Estimator, and to keep the variances for all indices in the model data? This would let users inspect the detailed variances computed during training, and would let a single training run serve multiple models by adjusting only the threshold. If the current model data is preferable, let's split the parameters into `VarianceThresholdSelectorParams` and `VarianceThresholdSelectorModelParams`, and make `VARIANCE_THRESHOLD` a parameter exclusive to the estimator. ########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelectorModel.java: ########## @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.variancethresholdselector; + +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** + * A Model which removes low-variance data using the model data computed by {@link + * VarianceThresholdSelector}. 
+ */ +public class VarianceThresholdSelectorModel + implements Model<VarianceThresholdSelectorModel>, + VarianceThresholdSelectorParams<VarianceThresholdSelectorModel> { + + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private Table modelDataTable; + + public VarianceThresholdSelectorModel() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public VarianceThresholdSelectorModel setModelData(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + modelDataTable = inputs[0]; + return this; + } + + @Override + public Table[] getModelData() { + return new Table[] {modelDataTable}; + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + ReadWriteUtils.saveModelData( + VarianceThresholdSelectorModelData.getModelDataStream(modelDataTable), + path, + new VarianceThresholdSelectorModelData.ModelDataEncoder()); + } + + public static VarianceThresholdSelectorModel load(StreamTableEnvironment tEnv, String path) + throws IOException { + VarianceThresholdSelectorModel model = ReadWriteUtils.loadStageParam(path); + Table modelDataTable = + ReadWriteUtils.loadModelData( + tEnv, path, new VarianceThresholdSelectorModelData.ModelDataDecoder()); + return model.setModelData(modelDataTable); + } + + @Override + public Table[] transform(Table... 
inputs) { + Preconditions.checkArgument(inputs.length == 1); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<Row> data = tEnv.toDataStream(inputs[0]); + DataStream<VarianceThresholdSelectorModelData> varianceThresholdSelectorModel = + VarianceThresholdSelectorModelData.getModelDataStream(modelDataTable); + + final String broadcastModelKey = "broadcastModelKey"; + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll( + inputTypeInfo.getFieldTypes(), + TypeInformation.of(DenseVector.class)), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol())); + + DataStream<Row> output = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(data), + Collections.singletonMap(broadcastModelKey, varianceThresholdSelectorModel), + inputList -> { + DataStream input = inputList.get(0); + return input.map( + new PredictOutputFunction(getFeaturesCol(), broadcastModelKey), + outputTypeInfo); + }); + + return new Table[] {tEnv.fromDataStream(output)}; + } + + /** This operator loads model data and predicts result. 
*/ + private static class PredictOutputFunction extends RichMapFunction<Row, Row> { + + private final String featureCol; + private final String broadcastKey; + private int expectedNumOfFeatures; + private int[] indices = null; + + public PredictOutputFunction(String featureCol, String broadcastKey) { + this.featureCol = featureCol; + this.broadcastKey = broadcastKey; + } + + @Override + public Row map(Row row) { + if (indices == null) { + VarianceThresholdSelectorModelData varianceThresholdSelectorModelData = + (VarianceThresholdSelectorModelData) + getRuntimeContext().getBroadcastVariable(broadcastKey).get(0); + expectedNumOfFeatures = varianceThresholdSelectorModelData.numOfFeatures; + indices = varianceThresholdSelectorModelData.indices; + } + + if (indices.length == 0) { + return Row.join(row, Row.of(Vectors.dense())); + } else { + DenseVector inputVec = ((Vector) row.getField(featureCol)).toDense(); + Preconditions.checkArgument( Review Comment: It seems that this check is also applicable when `indices.length == 0`. Let's move this part of code to above the `if` condition. ########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelector.java: ########## @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.variancethresholdselector; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.stream.IntStream; + +/** + * An Estimator which implements the VarianceThresholdSelector algorithm. The algorithm removes all + * low-variance features. Features with a variance not greater than the threshold will be removed. 
+ * The default is to keep all features with non-zero variance, i.e. remove the features that have + * the same value in all samples. + */ +public class VarianceThresholdSelector + implements Estimator<VarianceThresholdSelector, VarianceThresholdSelectorModel>, + VarianceThresholdSelectorParams<VarianceThresholdSelector> { + + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public VarianceThresholdSelector() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public VarianceThresholdSelectorModel fit(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + final String featuresCol = getFeaturesCol(); + final double varianceThreshold = getVarianceThreshold(); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<DenseVector> inputData = + tEnv.toDataStream(inputs[0]) + .map( + (MapFunction<Row, DenseVector>) + value -> ((Vector) value.getField(featuresCol)).toDense()); + DataStream<DenseVector> varianceValues = + inputData + .transform( + "CalculateVarianceOperator", + inputData.getType(), + new VarianceThresholdSelector.VarianceFunctionOperator()) + .setParallelism(1); + DataStream<VarianceThresholdSelectorModelData> modelData = + DataStreamUtils.mapPartition( + varianceValues, + new RichMapPartitionFunction< + DenseVector, VarianceThresholdSelectorModelData>() { + @Override + public void mapPartition( + Iterable<DenseVector> values, + Collector<VarianceThresholdSelectorModelData> out) { + DenseVector varianceVector = values.iterator().next(); + int[] indices = + IntStream.range(0, varianceVector.size()) + .filter( + i -> + varianceVector.get(i) + > varianceThreshold) + .toArray(); + out.collect( + new VarianceThresholdSelectorModelData( + varianceVector.size(), indices)); + } + }); + + VarianceThresholdSelectorModel model = + new VarianceThresholdSelectorModel().setModelData(tEnv.fromDataStream(modelData)); + 
ReadWriteUtils.updateExistingParams(model, getParamMap()); + return model; + } + + /** + * A stream operator to compute the variance from feature column of the input bounded data + * stream. + */ + public static class VarianceFunctionOperator extends AbstractStreamOperator<DenseVector> + implements OneInputStreamOperator<DenseVector, DenseVector>, BoundedOneInput { + + private ListState<DenseVector> sumState; + private ListState<DenseVector> squareSumState; + private ListState<Integer> numState; + + private int numRows = 0; + private DenseVector sumVector; + private DenseVector squareSumVector; + private DenseVector varianceVector; Review Comment: nit: This can be converted to a local variable. ########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelector.java: ########## @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.feature.variancethresholdselector; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.stream.IntStream; + +/** + * An Estimator which implements the VarianceThresholdSelector algorithm. The algorithm removes all + * low-variance features. Features with a variance not greater than the threshold will be removed. + * The default is to keep all features with non-zero variance, i.e. remove the features that have + * the same value in all samples. 
+ */ +public class VarianceThresholdSelector + implements Estimator<VarianceThresholdSelector, VarianceThresholdSelectorModel>, + VarianceThresholdSelectorParams<VarianceThresholdSelector> { + + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public VarianceThresholdSelector() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public VarianceThresholdSelectorModel fit(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + final String featuresCol = getFeaturesCol(); + final double varianceThreshold = getVarianceThreshold(); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<DenseVector> inputData = + tEnv.toDataStream(inputs[0]) + .map( + (MapFunction<Row, DenseVector>) + value -> ((Vector) value.getField(featuresCol)).toDense()); + DataStream<DenseVector> varianceValues = + inputData + .transform( + "CalculateVarianceOperator", + inputData.getType(), + new VarianceThresholdSelector.VarianceFunctionOperator()) Review Comment: nit: `new VarianceFunctionOperator()` is enough. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
