yunfengzhou-hub commented on code in PR #166:
URL: https://github.com/apache/flink-ml/pull/166#discussion_r1009085375
########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/imputer/Imputer.java: ##########
@@ -0,0 +1,333 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.imputer;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.util.QuantileSummary;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * The imputer for completing missing values of the input columns.
+ *
+ * <p>Missing values can be imputed using the statistics (mean, median or most frequent) of each
+ * column in which the missing values are located. The input columns should be of numeric type.

Review Comment: Could you please add test cases to verify that this algorithm can work on numeric values other than doubles, like integers or floats?
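A test along the requested lines might look like the sketch below. This is a minimal outline, assuming JUnit and the usual Flink ML table bootstrap; the column names, setter names and expected means are illustrative, not code from this PR:

```java
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.types.Row;
import org.apache.flink.util.CloseableIterator;
import org.junit.Test;

// Hypothetical test: the Imputer should accept INT and FLOAT input columns,
// not only DOUBLE ones.
@Test
public void testFitTransformOnIntAndFloatColumns() throws Exception {
    StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
    StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);

    Table input =
            tEnv.fromValues(
                    DataTypes.ROW(
                            DataTypes.FIELD("f1", DataTypes.INT()),
                            DataTypes.FIELD("f2", DataTypes.FLOAT())),
                    Row.of(1, 2.0f),
                    Row.of(3, null),
                    Row.of(null, 6.0f));

    Table output =
            new Imputer()
                    .setInputCols("f1", "f2")
                    .setOutputCols("o1", "o2")
                    .setStrategy("mean")
                    .fit(input)
                    .transform(input)[0];

    // With the mean strategy, the missing cells should be imputed
    // with 2.0 (column f1) and 4.0 (column f2).
    try (CloseableIterator<Row> it = tEnv.toDataStream(output).executeAndCollect()) {
        while (it.hasNext()) {
            Row row = it.next();
            // ... assert that "o1" and "o2" never contain nulls ...
        }
    }
}
```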
########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/imputer/ImputerModel.java: ##########
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.imputer;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/** A Model which replaces the missing values using the model data computed by {@link Imputer}. */
+public class ImputerModel implements Model<ImputerModel>, ImputerModelParams<ImputerModel> {
+
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public ImputerModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public ImputerModel setModelData(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                ImputerModelData.getModelDataStream(modelDataTable),
+                path,
+                new ImputerModelData.ModelDataEncoder());
+    }
+
+    public static ImputerModel load(StreamTableEnvironment tEnv, String path) throws IOException {
+        ImputerModel model = ReadWriteUtils.loadStageParam(path);
+        Table modelDataTable =
+                ReadWriteUtils.loadModelData(tEnv, path, new ImputerModelData.ModelDataDecoder());
+        return model.setModelData(modelDataTable);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        String[] inputCols = getInputCols();
+        String[] outputCols = getOutputCols();
+        Preconditions.checkArgument(inputCols.length == outputCols.length);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Row> dataStream = tEnv.toDataStream(inputs[0]);
+        DataStream<ImputerModelData> imputerModel =
+                ImputerModelData.getModelDataStream(modelDataTable);
+
+        final String broadcastModelKey = "broadcastModelKey";
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        TypeInformation<?>[] outputTypes = new TypeInformation[outputCols.length];
+        Arrays.fill(outputTypes, BasicTypeInfo.DOUBLE_TYPE_INFO);
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), outputTypes),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), outputCols));
+
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(dataStream),
+                        Collections.singletonMap(broadcastModelKey, imputerModel),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(
+                                    new PredictOutputFunction(
+                                            getMissingValue(), inputCols, broadcastModelKey),
+                                    outputTypeInfo);
+                        });
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    /** This operator loads model data and predicts results. */
+    private static class PredictOutputFunction extends RichMapFunction<Row, Row> {
+
+        private final String[] inputCols;
+        private final String broadcastKey;
+        private final double missingValue;
+        private Map<String, Double> surrogates;
+
+        public PredictOutputFunction(double missingValue, String[] inputCols, String broadcastKey) {
+            this.missingValue = missingValue;
+            this.inputCols = inputCols;
+            this.broadcastKey = broadcastKey;
+        }
+
+        @Override
+        public Row map(Row row) throws Exception {
+            if (surrogates == null) {
+                ImputerModelData imputerModelData =
+                        (ImputerModelData)
+                                getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
+                surrogates = imputerModelData.surrogates;
+                Arrays.stream(inputCols)
+                        .forEach(
+                                col ->
+                                        Preconditions.checkArgument(
+                                                surrogates.containsKey(col),
+                                                "Column %s is unacceptable for the Imputer model.",
+                                                col));
+            }
+
+            Row outputRow = new Row(inputCols.length);
+            for (int i = 0; i < inputCols.length; i++) {
+                Double value = (Double) row.getField(i);
+                boolean shouldReplace;
+                if (Double.isNaN(missingValue)) {
+                    shouldReplace = value == null || Double.isNaN(value);
+                } else {
+                    shouldReplace = value == null || value == missingValue;
+                }

Review Comment: `boolean shouldReplace = value == null || value.equals(missingValue);` should be enough. Double.NaN is not a corner case.
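A standalone illustration of why the single boxed comparison suffices (not code from this PR): `Double.equals` compares bit patterns via `doubleToLongBits`, so it already treats NaN as equal to NaN, whereas the primitive `==` operator does not.

```java
// Demo of the boxed-vs-primitive NaN comparison semantics.
public class NaNEqualsDemo {
    public static void main(String[] args) {
        Double value = Double.NaN;        // the cell read from the row
        double missingValue = Double.NaN; // the configured sentinel

        // The suggested one-liner; missingValue is autoboxed for equals().
        boolean shouldReplace = value == null || value.equals(missingValue);

        System.out.println(shouldReplace);         // true
        System.out.println(value == missingValue); // false: primitive NaN != NaN
    }
}
```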
########## flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasRelativeError.java: ##########
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for shared param relativeError. */
+public interface HasRelativeError<T> extends WithParams<T> {
+    Param<Double> RELATIVE_ERROR =
+            new DoubleParam(
+                    "relativeError",
+                    "The relative target precision for the approximate quantile algorithm. Must be in the range (0, 1).",
+                    0.001,
+                    ParamValidators.inRange(0, 1, false, false));

Review Comment: In Spark, 0 and 1 are also valid values for this parameter. Should we also support them in Flink ML?
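A sketch of the suggested relaxation, assuming the two-argument `ParamValidators.inRange(a, b)` overload is bound-inclusive as elsewhere in Flink ML. For reference, in Spark a relative error of 0 requests exact (but more expensive) quantile computation:

```java
// Hypothetical change, not code from this PR: accept the closed range [0, 1].
Param<Double> RELATIVE_ERROR =
        new DoubleParam(
                "relativeError",
                "The relative target precision for the approximate quantile algorithm.",
                0.001,
                ParamValidators.inRange(0, 1));
```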
########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/imputer/Imputer.java: ##########
@@ -0,0 +1,333 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.imputer;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.util.QuantileSummary;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * The imputer for completing missing values of the input columns.
+ *
+ * <p>Missing values can be imputed using the statistics (mean, median or most frequent) of each
+ * column in which the missing values are located. The input columns should be of numeric type.
+ *
+ * <p>Note that the mean/median/most_frequent value is computed after filtering out missing values.
+ * All null values in the input columns are also treated as missing, and so are imputed.
+ */
+public class Imputer implements Estimator<Imputer, ImputerModel>, ImputerParams<Imputer> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public Imputer() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public ImputerModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        Preconditions.checkArgument(
+                getInputCols().length == getOutputCols().length,
+                "Num of input columns and output columns are inconsistent.");
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Row> inputData = tEnv.toDataStream(inputs[0]);
+
+        DataStream<ImputerModelData> modelData;
+        switch (getStrategy()) {
+            case MEAN:
+                modelData =
+                        DataStreamUtils.aggregate(
+                                inputData,
+                                new MeanStrategyAggregator(getInputCols(), getMissingValue()));
+                break;
+            case MEDIAN:
+                modelData =
+                        DataStreamUtils.aggregate(
+                                inputData,
+                                new MedianStrategyAggregator(
+                                        getInputCols(), getMissingValue(), getRelativeError()));
+                break;
+            case MOST_FREQUENT:
+                modelData =
+                        DataStreamUtils.aggregate(
+                                inputData,
+                                new MostFrequentStrategyAggregator(
+                                        getInputCols(), getMissingValue()));
+                break;
+            default:
+                throw new RuntimeException("Unsupported strategy of Imputer: " + getStrategy());
+        }
+        ImputerModel model = new ImputerModel().setModelData(tEnv.fromDataStream(modelData));
+        ReadWriteUtils.updateExistingParams(model, getParamMap());
+        return model;
+    }
+
+    /**
+     * A stream operator to compute the mean value of all input columns of the input bounded data
+     * stream.
+     */
+    private static class MeanStrategyAggregator
+            implements AggregateFunction<Row, Map<String, Tuple2<Double, Long>>, ImputerModelData> {
+
+        private final String[] columnNames;
+        private final double missingValue;
+
+        public MeanStrategyAggregator(String[] columnNames, double missingValue) {
+            this.columnNames = columnNames;
+            this.missingValue = missingValue;
+        }
+
+        @Override
+        public Map<String, Tuple2<Double, Long>> createAccumulator() {
+            Map<String, Tuple2<Double, Long>> accumulators = new HashMap<>();
+            Arrays.stream(columnNames).forEach(x -> accumulators.put(x, Tuple2.of(0.0, 0L)));
+            return accumulators;
+        }
+
+        @Override
+        public Map<String, Tuple2<Double, Long>> add(
+                Row row, Map<String, Tuple2<Double, Long>> accumulators) {
+            accumulators.forEach(
+                    (col, sumAndNum) -> {
+                        Double rawValue = (Double) row.getField(col);
+                        boolean shouldBypass =
+                                rawValue == null
+                                        || Double.isNaN(rawValue)
+                                        || rawValue == missingValue;
+                        if (!shouldBypass) {
+                            sumAndNum.f0 += rawValue;
+                            sumAndNum.f1 += 1;
+                        }
+                    });
+            return accumulators;
+        }
+
+        @Override
+        public ImputerModelData getResult(Map<String, Tuple2<Double, Long>> map) {
+            long numRows = map.entrySet().stream().findFirst().get().getValue().f1;
+            Preconditions.checkState(
+                    numRows > 0, "The training set is empty or does not contain valid data.");
+
+            Map<String, Double> surrogates = new HashMap<>();
+            map.forEach((col, sumAndNum) -> surrogates.put(col, sumAndNum.f0 / sumAndNum.f1));
+            return new ImputerModelData(surrogates);
+        }
+
+        @Override
+        public Map<String, Tuple2<Double, Long>> merge(
+                Map<String, Tuple2<Double, Long>> acc1, Map<String, Tuple2<Double, Long>> acc2) {
+            Preconditions.checkArgument(acc1.size() == acc2.size());
+
+            acc1.forEach(
+                    (col, numAndSum) -> {
+                        acc2.get(col).f0 += numAndSum.f0;
+                        acc2.get(col).f1 += numAndSum.f1;
+                    });
+            return acc2;
+        }
+    }
+
+    /**
+     * A stream operator to compute the median value of all input columns of the input bounded data
+     * stream.
+     */
+    private static class MedianStrategyAggregator
+            implements AggregateFunction<Row, Map<String, QuantileSummary>, ImputerModelData> {
+        private final String[] columnNames;
+        private final double missingValue;
+        private final double relativeError;
+
+        public MedianStrategyAggregator(
+                String[] columnNames, double missingValue, double relativeError) {
+            this.columnNames = columnNames;
+            this.missingValue = missingValue;
+            this.relativeError = relativeError;
+        }
+
+        @Override
+        public Map<String, QuantileSummary> createAccumulator() {
+            Map<String, QuantileSummary> summaries = new HashMap<>();
+            Arrays.stream(columnNames)
+                    .forEach(x -> summaries.put(x, new QuantileSummary(relativeError)));
+            return summaries;
+        }
+
+        @Override
+        public Map<String, QuantileSummary> add(Row row, Map<String, QuantileSummary> summaries) {
+            summaries.forEach(
+                    (col, summary) -> {
+                        Double rawValue = (Double) row.getField(col);
+                        boolean shouldBypass =
+                                rawValue == null
+                                        || Double.isNaN(rawValue)
+                                        || rawValue == missingValue;
+                        if (!shouldBypass) {
+                            summary.insert(rawValue);
+                        }
+                    });
+            return summaries;
+        }
+
+        @Override
+        public ImputerModelData getResult(Map<String, QuantileSummary> summaries) {
+            Preconditions.checkState(
+                    !summaries.entrySet().stream().findFirst().get().getValue().isEmpty(),
+                    "The training set is empty or does not contain valid data.");
+
+            Map<String, Double> surrogates = new HashMap<>();
+            summaries.forEach(
+                    (col, summary) -> {
+                        QuantileSummary compressed = summary.compress();
+                        double median = compressed.query(0.5);
+                        surrogates.put(col, median);
+                    });
+            return new ImputerModelData(surrogates);
+        }
+
+        @Override
+        public Map<String, QuantileSummary> merge(
+                Map<String, QuantileSummary> acc1, Map<String, QuantileSummary> acc2) {
+            Preconditions.checkArgument(acc1.size() == acc2.size());
+
+            acc1.forEach(
+                    (col, summary1) -> {
+                        QuantileSummary summary2 = acc2.get(col).compress();
+                        acc2.put(col, summary2.merge(summary1.compress()));
+                    });
+            return acc2;
+        }
+    }
+
+    /**
+     * A stream operator to compute the most frequent value of all input columns of the input
+     * bounded data stream.
+     */
+    private static class MostFrequentStrategyAggregator
+            implements AggregateFunction<Row, Map<String, Map<Double, Long>>, ImputerModelData> {
+        private final String[] columnNames;
+        private final double missingValue;
+
+        public MostFrequentStrategyAggregator(String[] columnNames, double missingValue) {
+            this.columnNames = columnNames;
+            this.missingValue = missingValue;
+        }
+
+        @Override
+        public Map<String, Map<Double, Long>> createAccumulator() {
+            Map<String, Map<Double, Long>> accumulators = new HashMap<>();
+            Arrays.stream(columnNames).forEach(x -> accumulators.put(x, new HashMap<>()));
+            return accumulators;
+        }
+
+        @Override
+        public Map<String, Map<Double, Long>> add(
+                Row row, Map<String, Map<Double, Long>> accumulators) {
+            accumulators.forEach(
+                    (col, counts) -> {
+                        Double rawValue = (Double) row.getField(col);
+                        boolean shouldBypass =
+                                rawValue == null
+                                        || Double.isNaN(rawValue)
+                                        || rawValue == missingValue;
+                        if (!shouldBypass) {
+                            double value = rawValue;
+                            if (counts.containsKey(value)) {
+                                counts.put(value, counts.get(value) + 1);
+                            } else {
+                                counts.put(value, 1L);
+                            }
+                        }
+                    });
+            return accumulators;
+        }
+
+        @Override
+        public ImputerModelData getResult(Map<String, Map<Double, Long>> map) {
+            long validColumns =
+                    map.entrySet().stream().filter(x -> x.getValue().size() > 0).count();
+            Preconditions.checkState(
+                    validColumns > 0, "The training set is empty or does not contain valid data.");
+
+            Map<String, Double> surrogates = new HashMap<>();
+            map.forEach(
+                    (col, counts) -> {
+                        long maxCnt = Long.MIN_VALUE;
+                        double value = Double.NaN;
+                        for (Map.Entry<Double, Long> entry : counts.entrySet()) {
+                            if (maxCnt < entry.getValue()) {
+                                maxCnt = entry.getValue();
+                                value = entry.getKey();

Review Comment: Let's guarantee that when there are multiple max counts, the smallest value would be selected.
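One possible deterministic version of this selection loop (a sketch of the suggested change, not code from this PR): keep the larger count, and on a tie keep the smaller value.

```java
// Prefer the higher count; break ties by choosing the smaller value, so the
// surrogate no longer depends on HashMap iteration order.
long maxCnt = Long.MIN_VALUE;
double value = Double.NaN;
for (Map.Entry<Double, Long> entry : counts.entrySet()) {
    long cnt = entry.getValue();
    double candidate = entry.getKey();
    if (cnt > maxCnt || (cnt == maxCnt && candidate < value)) {
        maxCnt = cnt;
        value = candidate;
    }
}
surrogates.put(col, value);
```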
########## flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasRelativeError.java: ##########
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for shared param relativeError. */
+public interface HasRelativeError<T> extends WithParams<T> {
+    Param<Double> RELATIVE_ERROR =
+            new DoubleParam(
+                    "relativeError",
+                    "The relative target precision for the approximate quantile algorithm. Must be in the range (0, 1).",

Review Comment: I understand that Spark also has the description "Must be in the range...", but since it is not common practice in Flink ML to repeat a ParamValidator's constraint in a Param's description, I wonder whether it would be better to remove the last sentence.

########## docs/content/docs/operators/feature/imputer.md: ##########
@@ -0,0 +1,188 @@
+---
+title: "Imputer"
+weight: 1
+type: docs
+aliases:
+- /operators/feature/imputer.html
+---
+
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+  http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+## Imputer
+The imputer for completing missing values of the input columns.
+Missing values can be imputed using the statistics (mean, median or most frequent) of each column in which the missing values are located. The input columns should be of numeric type.
+
+__Note__ The mean/median/most_frequent value is computed after filtering out missing values. Null values are always treated as missing, and so are also imputed.
+
+### Input Columns
+
+| Param name | Type   | Default | Description             |
+|:-----------|:-------|:--------|:------------------------|
+| inputCols  | Number | `null`  | Features to be imputed. |
+
+### Output Columns
+
+| Param name | Type   | Default | Description       |
+|:-----------|:-------|:--------|:------------------|
+| outputCols | Double | `null`  | Imputed features. |
+
+### Parameters
+
+Below are the parameters required by `ImputerModel`.
+
+| Key           | Default  | Type     | Required | Description                                                                    |
+|---------------|----------|----------|----------|--------------------------------------------------------------------------------|
+| inputCols     | `null`   | String[] | yes      | Input column names.                                                            |
+| outputCols    | `null`   | String[] | yes      | Output column names.                                                           |
+| strategy      | `"mean"` | String   | no       | The imputation strategy. Supported values: 'mean', 'median', 'most_frequent'.  |
+| relativeError | `0.001`  | Double   | no       | The relative target precision, only effective when the strategy is 'median'.  |

Review Comment: Let's keep the description the same across JavaDoc and markdown documents. If we want to make the point that `relativeError` only works with a certain imputation strategy, we can place this point in the javadoc and the markdown description of the whole class.
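For readers of this docs page, the parameters above are typically wired together along these lines. The setter names are assumed from Flink ML's usual param conventions, and the column names and `trainTable` are hypothetical:

```java
// Illustrative usage sketch, not an excerpt from this PR.
Table trainTable = ...; // a bounded table with numeric columns "f1" and "f2"

Imputer imputer =
        new Imputer()
                .setInputCols("f1", "f2")
                .setOutputCols("o1", "o2")
                .setStrategy("median")
                .setRelativeError(0.001); // only consulted by the 'median' strategy

ImputerModel model = imputer.fit(trainTable);
Table result = model.transform(trainTable)[0];
```

-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]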
