lindong28 commented on a change in pull request #54: URL: https://github.com/apache/flink-ml/pull/54#discussion_r831052100
########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java ########## @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.feature.minmaxscaler.MinMaxScaler; +import org.apache.flink.ml.feature.minmaxscaler.MinMaxScalerModel; +import org.apache.flink.ml.feature.minmaxscaler.MinMaxScalerModelData; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.ml.util.StageTestUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.junit.Assert.assertEquals; + +/** Tests {@link MinMaxScaler} and {@link MinMaxScalerModel}. 
*/ +public class MinMaxScalerTest { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + private Table trainDataTable; + private Table predictDataTable; + private static final List<Row> trainData = + new ArrayList<>( + Arrays.asList( + Row.of(Vectors.dense(0.0, 3.0)), + Row.of(Vectors.dense(2.1, 0.0)), + Row.of(Vectors.dense(4.1, 5.1)), + Row.of(Vectors.dense(6.1, 8.1)), + Row.of(Vectors.dense(200, 400)))); + private static final List<Row> predictData = + new ArrayList<>(Collections.singletonList(Row.of(Vectors.dense(150.0, 90.0)))); + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(4); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + trainDataTable = tEnv.fromDataStream(env.fromCollection(trainData)).as("features"); + predictDataTable = tEnv.fromDataStream(env.fromCollection(predictData)).as("features"); + } + + private static void verifyPredictionResult(Table output, String outputCol, DenseVector expected) + throws Exception { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment(); + DataStream<DenseVector> stream = + tEnv.toDataStream(output) + .map( + (MapFunction<Row, DenseVector>) + row -> (DenseVector) row.getField(outputCol)); + List<DenseVector> result = IteratorUtils.toList(stream.executeAndCollect()); + assertEquals(1, result.size()); + assertEquals(expected, result.get(0)); + } + + @Test + public void testParam() { + MinMaxScaler minMaxScaler = new MinMaxScaler(); + assertEquals("features", minMaxScaler.getFeaturesCol()); + assertEquals("prediction", minMaxScaler.getPredictionCol()); + assertEquals(1.0, minMaxScaler.getMax(), 0.0001); + assertEquals(0.0, minMaxScaler.getMin(), 0.0001); + minMaxScaler + .setFeaturesCol("test_features") + .setPredictionCol("test_output") + .setMax(4.0) + .setMin(1.0); + assertEquals("test_features", minMaxScaler.getFeaturesCol()); + assertEquals(1.0, minMaxScaler.getMin(), 0.0001); + assertEquals(4.0, minMaxScaler.getMax(), 0.0001); + assertEquals("test_output", minMaxScaler.getPredictionCol()); + } + + @Test + public void testFeaturePredictionParam() { + MinMaxScaler minMaxScaler = + new MinMaxScaler() + .setFeaturesCol("test_features") + .setPredictionCol("test_output") + .setMin(1.0) + .setMax(4.0); + + MinMaxScalerModel model = minMaxScaler.fit(trainDataTable.as("test_features")); + Table output = model.transform(predictDataTable.as("test_features"))[0]; + assertEquals( + Arrays.asList("test_features", "test_output"), + output.getResolvedSchema().getColumnNames()); + } + + @Test + public void testFewerSamplesThanParallel() throws Exception { + MinMaxScaler minMaxScaler = new MinMaxScaler(); + MinMaxScalerModel model = minMaxScaler.fit(predictDataTable); + Table result = model.transform(predictDataTable)[0]; + verifyPredictionResult(result, minMaxScaler.getPredictionCol(), Vectors.dense(0.5, 0.5)); + } + + @Test + public void testMaxValueEqualsMinValueButPredictValueNotEquals() throws Exception { + List<Row> userDefineData = + new ArrayList<>(Collections.singletonList(Row.of(Vectors.dense(30.0, 50.0)))); + Table userDefineDataTable = + 
tEnv.fromDataStream(env.fromCollection(userDefineData)).as("features"); + MinMaxScaler minMaxScaler = new MinMaxScaler().setMax(10.0).setMin(0.0); + MinMaxScalerModel model = minMaxScaler.fit(predictDataTable); + Table result = model.transform(userDefineDataTable)[0]; + verifyPredictionResult(result, minMaxScaler.getPredictionCol(), Vectors.dense(5.0, 5.0)); + } + + @Test + public void testFitAndPredict() throws Exception { + MinMaxScaler minMaxScaler = new MinMaxScaler(); + MinMaxScalerModel minMaxScalerModel = minMaxScaler.fit(trainDataTable); + Table output = minMaxScalerModel.transform(predictDataTable)[0]; + verifyPredictionResult(output, minMaxScaler.getPredictionCol(), Vectors.dense(0.75, 0.225)); + } + + @Test + public void testSaveLoadAndPredict() throws Exception { + MinMaxScaler minMaxScaler = new MinMaxScaler(); + MinMaxScaler loadedMinMaxScaler = + StageTestUtils.saveAndReload( + env, minMaxScaler, tempFolder.newFolder().getAbsolutePath()); + MinMaxScalerModel minMaxScalerModel = loadedMinMaxScaler.fit(trainDataTable); + minMaxScalerModel = Review comment: nits: should we replace `minMaxScalerModel = ...` with `MinMaxScalerModel loadedMinMaxScalerModel = ...` so that the name is consistent with `loadedMinMaxScaler` used above? This would also be consistent with the naming pattern used in other tests (e.g. `KMeansTest::testSaveLoadAndPredict`). If the name is too long, feel free to simplify the names to `model` and `loadedModel`. ########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java ########## @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.flink.ml.feature; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.feature.minmaxscaler.MinMaxScaler; +import org.apache.flink.ml.feature.minmaxscaler.MinMaxScalerModel; +import org.apache.flink.ml.feature.minmaxscaler.MinMaxScalerModelData; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.ml.util.StageTestUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.junit.Assert.assertEquals; + +/** Tests {@link MinMaxScaler} and {@link MinMaxScalerModel}. */ +public class MinMaxScalerTest { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + private Table trainDataTable; + private Table predictDataTable; + private static final List<Row> trainData = + new ArrayList<>( + Arrays.asList( + Row.of(Vectors.dense(0.0, 3.0)), + Row.of(Vectors.dense(2.1, 0.0)), + Row.of(Vectors.dense(4.1, 5.1)), + Row.of(Vectors.dense(6.1, 8.1)), + Row.of(Vectors.dense(200, 400)))); + private static final List<Row> predictData = + new ArrayList<>(Collections.singletonList(Row.of(Vectors.dense(150.0, 90.0)))); + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(4); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + trainDataTable = tEnv.fromDataStream(env.fromCollection(trainData)).as("features"); + predictDataTable = tEnv.fromDataStream(env.fromCollection(predictData)).as("features"); + } + + private static void verifyPredictionResult(Table output, String outputCol, DenseVector expected) + throws Exception { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment(); + DataStream<DenseVector> stream = + tEnv.toDataStream(output) + .map( + (MapFunction<Row, DenseVector>) + row -> (DenseVector) row.getField(outputCol)); + List<DenseVector> result = IteratorUtils.toList(stream.executeAndCollect()); + assertEquals(1, result.size()); + assertEquals(expected, result.get(0)); + } + + @Test + public void testParam() { + MinMaxScaler minMaxScaler = new MinMaxScaler(); + assertEquals("features", minMaxScaler.getFeaturesCol()); + assertEquals("prediction", minMaxScaler.getPredictionCol()); + assertEquals(1.0, minMaxScaler.getMax(), 0.0001); + assertEquals(0.0, minMaxScaler.getMin(), 0.0001); + 
minMaxScaler + .setFeaturesCol("test_features") + .setPredictionCol("test_output") + .setMax(4.0) + .setMin(1.0); + assertEquals("test_features", minMaxScaler.getFeaturesCol()); + assertEquals(1.0, minMaxScaler.getMin(), 0.0001); + assertEquals(4.0, minMaxScaler.getMax(), 0.0001); + assertEquals("test_output", minMaxScaler.getPredictionCol()); + } + + @Test + public void testFeaturePredictionParam() { + MinMaxScaler minMaxScaler = + new MinMaxScaler() + .setFeaturesCol("test_features") + .setPredictionCol("test_output") + .setMin(1.0) + .setMax(4.0); + + MinMaxScalerModel model = minMaxScaler.fit(trainDataTable.as("test_features")); + Table output = model.transform(predictDataTable.as("test_features"))[0]; + assertEquals( + Arrays.asList("test_features", "test_output"), + output.getResolvedSchema().getColumnNames()); + } + + @Test + public void testFewerSamplesThanParallel() throws Exception { + MinMaxScaler minMaxScaler = new MinMaxScaler(); + MinMaxScalerModel model = minMaxScaler.fit(predictDataTable); + Table result = model.transform(predictDataTable)[0]; + verifyPredictionResult(result, minMaxScaler.getPredictionCol(), Vectors.dense(0.5, 0.5)); + } + + @Test + public void testMaxValueEqualsMinValueButPredictValueNotEquals() throws Exception { + List<Row> userDefineData = + new ArrayList<>(Collections.singletonList(Row.of(Vectors.dense(30.0, 50.0)))); + Table userDefineDataTable = + tEnv.fromDataStream(env.fromCollection(userDefineData)).as("features"); + MinMaxScaler minMaxScaler = new MinMaxScaler().setMax(10.0).setMin(0.0); + MinMaxScalerModel model = minMaxScaler.fit(predictDataTable); + Table result = model.transform(userDefineDataTable)[0]; + verifyPredictionResult(result, minMaxScaler.getPredictionCol(), Vectors.dense(5.0, 5.0)); + } + + @Test + public void testFitAndPredict() throws Exception { + MinMaxScaler minMaxScaler = new MinMaxScaler(); + MinMaxScalerModel minMaxScalerModel = minMaxScaler.fit(trainDataTable); + Table output = minMaxScalerModel.transform(predictDataTable)[0]; + verifyPredictionResult(output, minMaxScaler.getPredictionCol(), Vectors.dense(0.75, 0.225)); + } + + @Test + public void testSaveLoadAndPredict() throws Exception { + MinMaxScaler minMaxScaler = new MinMaxScaler(); + MinMaxScaler loadedMinMaxScaler = + StageTestUtils.saveAndReload( + env, minMaxScaler, tempFolder.newFolder().getAbsolutePath()); + MinMaxScalerModel minMaxScalerModel = loadedMinMaxScaler.fit(trainDataTable); + minMaxScalerModel = + StageTestUtils.saveAndReload( + env, minMaxScalerModel, tempFolder.newFolder().getAbsolutePath()); + assertEquals( + Arrays.asList("minVector", "maxVector"), + minMaxScalerModel.getModelData()[0].getResolvedSchema().getColumnNames()); + Table output = minMaxScalerModel.transform(predictDataTable)[0]; + verifyPredictionResult(output, minMaxScaler.getPredictionCol(), Vectors.dense(0.75, 0.225)); + } + + @Test + public void testGetModelData() throws Exception { + MinMaxScaler minMaxScaler = new MinMaxScaler(); + MinMaxScalerModel minMaxScalerModel = minMaxScaler.fit(trainDataTable); + Table modelData = minMaxScalerModel.getModelData()[0]; + DataStream<Row> output = tEnv.toDataStream(modelData); + assertEquals("minVector", modelData.getResolvedSchema().getColumnNames().get(0)); Review comment: nits: could we replace these two lines with the following one line? This would make the code a bit simpler and also more extensible if in the future we want to test schemas with 3+ fields. 
``` assertEquals( Arrays.asList("minVector", "maxVector"), modelData.getResolvedSchema().getColumnNames()); ``` ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScaler.java ########## @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.minmaxscaler; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; + +/** + * An Estimator which implements the MinMaxScaler algorithm. This algorithm rescales feature values + * to a common range [min, max] which defined by user. + * + * <blockquote> + * + * $$ Rescaled(value) = \frac{value - E_{min}}{E_{max} - E_{min}} * (max - min) + min $$ + * + * </blockquote> + * + * <p>For the case \(E_{max} == E_{min}\), \(Rescaled(value) = 0.5 * (max + min)\). + * + * <p>See https://en.wikipedia.org/wiki/Feature_scaling#Rescaling_(min-max_normalization). + */ +public class MinMaxScaler + implements Estimator<MinMaxScaler, MinMaxScalerModel>, MinMaxScalerParams<MinMaxScaler> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public MinMaxScaler() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public MinMaxScalerModel fit(Table... 
inputs) { + Preconditions.checkArgument(inputs.length == 1); + final String featureCol = getFeaturesCol(); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<DenseVector> features = + tEnv.toDataStream(inputs[0]) + .map( + (MapFunction<Row, DenseVector>) + value -> (DenseVector) value.getField(featureCol)); + DataStream<DenseVector> minMaxValues = + features.transform( + "reduceInEachPartition", + features.getType(), + new MinMaxReduceFunctionOperator()) + .transform( + "reduceInFinalPartition", + features.getType(), + new MinMaxReduceFunctionOperator()) + .setParallelism(1); + DataStream<MinMaxScalerModelData> modelData = + DataStreamUtils.mapPartition( + minMaxValues, + new RichMapPartitionFunction<DenseVector, MinMaxScalerModelData>() { + @Override + public void mapPartition( + Iterable<DenseVector> values, + Collector<MinMaxScalerModelData> out) { + Iterator<DenseVector> iter = values.iterator(); + DenseVector minVector = iter.next(); + DenseVector maxVector = iter.next(); + out.collect(new MinMaxScalerModelData(minVector, maxVector)); + } + }); + + MinMaxScalerModel model = + new MinMaxScalerModel().setModelData(tEnv.fromDataStream(modelData)); + ReadWriteUtils.updateExistingParams(model, getParamMap()); + return model; + } + + /** + * A stream operator to compute the min and max values in each partition of the input bounded + * data stream. + */ + private static class MinMaxReduceFunctionOperator extends AbstractStreamOperator<DenseVector> + implements OneInputStreamOperator<DenseVector, DenseVector>, BoundedOneInput { + private ListState<DenseVector> minState; + private ListState<DenseVector> maxState; + + private DenseVector minVector; + private DenseVector maxVector; + + @Override + public void endInput() { + if (minVector != null) { + output.collect(new StreamRecord<>(minVector)); + output.collect(new StreamRecord<>(maxVector)); + } + } + + @Override + public void processElement(StreamRecord<DenseVector> streamRecord) { + DenseVector currentValue = streamRecord.getValue(); + if (minVector == null) { + int vecSize = currentValue.size(); + minVector = new DenseVector(vecSize); + maxVector = new DenseVector(vecSize); + System.arraycopy(currentValue.values, 0, minVector.values, 0, vecSize); + System.arraycopy(currentValue.values, 0, maxVector.values, 0, vecSize); + } else { + Preconditions.checkArgument( + currentValue.size() == maxVector.size(), + "CurrentValue should has same size with maxVector."); + for (int i = 0; i < currentValue.size(); ++i) { + minVector.values[i] = Math.min(minVector.values[i], currentValue.values[i]); + maxVector.values[i] = Math.max(maxVector.values[i], currentValue.values[i]); + } + } + } + + @Override + @SuppressWarnings("unchecked") + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + minState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( Review comment: Hmm.. instead of using `getOperatorConfig().getTypeSerializerIn(...)`, would it be better to use the code below for simplicity and consistency with other usages of `getListState(...)` in Flink ML? ``` new ListStateDescriptor<>("minState", TypeInformation.of(DenseVector.class)) ``` ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModel.java ########## @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.minmaxscaler; + +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.table.runtime.typeutils.ExternalTypeInfo; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** + * A Model which do a minMax scaler operation using the model data computed by {@link MinMaxScaler}. + */ +public class MinMaxScalerModel + implements Model<MinMaxScalerModel>, MinMaxScalerParams<MinMaxScalerModel> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private Table modelDataTable; + + public MinMaxScalerModel() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public MinMaxScalerModel setModelData(Table... inputs) { + modelDataTable = inputs[0]; + return this; + } + + @Override + public Table[] getModelData() { + return new Table[] {modelDataTable}; + } + + @Override + @SuppressWarnings("unchecked") + public Table[] transform(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<Row> data = tEnv.toDataStream(inputs[0]); + DataStream<MinMaxScalerModelData> minMaxScalerModel = + MinMaxScalerModelData.getModelDataStream(modelDataTable); + final String broadcastModelKey = "broadcastModelKey"; + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll( + inputTypeInfo.getFieldTypes(), + ExternalTypeInfo.of(DenseVector.class)), Review comment: Hmm.. is there any reason in particular to use `ExternalTypeInfo`? Would it be better to use `TypeInformation.of(DenseVector.class)` for consistency with e.g. `LogisticRegressionModel::transform(...)`? 
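 For illustration, here is a minimal sketch of the suggested swap in `MinMaxScalerModel#transform(...)`, assuming the `ExternalTypeInfo` import is replaced with `org.apache.flink.api.common.typeinfo.TypeInformation` (already imported in `MinMaxScalerModelData`); this is a sketch of the suggestion rather than final code:
```
// Sketch: build the output row type with TypeInformation.of(...) instead of
// ExternalTypeInfo.of(...), mirroring LogisticRegressionModel::transform(...).
RowTypeInfo outputTypeInfo =
        new RowTypeInfo(
                ArrayUtils.addAll(
                        inputTypeInfo.getFieldTypes(),
                        TypeInformation.of(DenseVector.class)),
                ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol()));
```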
########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerParams.java ########## @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.minmaxscaler; + +import org.apache.flink.ml.common.param.HasFeaturesCol; +import org.apache.flink.ml.common.param.HasPredictionCol; +import org.apache.flink.ml.param.DoubleParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; + +/** + * Params for {@link MinMaxScaler}. + * + * @param <T> The class type of this instance. + */ +public interface MinMaxScalerParams<T> extends HasFeaturesCol<T>, HasPredictionCol<T> { + Param<Double> MAX = + new DoubleParam( + "max", "Upper bound after transformation.", 1.0, ParamValidators.notNull()); Review comment: Spark uses `upper bound of the output feature range` as the description for this parameter. It is useful to explicitly specify what this upper bound applies to (i.e. `the output`). In comparison, `Upper bound after transformation` does not explicitly specify this information. Do you think it would be a bit better to use Spark's description here? Same for the `min` parameter. ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerParams.java ########## @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.minmaxscaler; + +import org.apache.flink.ml.common.param.HasFeaturesCol; +import org.apache.flink.ml.common.param.HasPredictionCol; +import org.apache.flink.ml.param.DoubleParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; + +/** + * Params for {@link MinMaxScaler}. + * + * @param <T> The class type of this instance.
+ */ +public interface MinMaxScalerParams<T> extends HasFeaturesCol<T>, HasPredictionCol<T> { + Param<Double> MAX = + new DoubleParam( + "max", "Upper bound after transformation.", 1.0, ParamValidators.notNull()); + + default Double getMax() { + return get(MAX); + } + + default T setMax(Double value) { + return set(MAX, value); + } + + Param<Double> MIN = Review comment: nits: Could we move the definition of this MIN parameter to be above MAX? This would be more consistent with the Java coding style, where we typically declare all class member variables before declaring methods. And it would also make it easier for users to get all the parameters defined in this class. ########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java ########## @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.feature.minmaxscaler.MinMaxScaler; +import org.apache.flink.ml.feature.minmaxscaler.MinMaxScalerModel; +import org.apache.flink.ml.feature.minmaxscaler.MinMaxScalerModelData; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.ml.util.StageTestUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.junit.Assert.assertEquals; + +/** Tests {@link MinMaxScaler} and {@link MinMaxScalerModel}. 
*/ +public class MinMaxScalerTest { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + private Table trainDataTable; + private Table predictDataTable; + private static final List<Row> trainData = + new ArrayList<>( + Arrays.asList( + Row.of(Vectors.dense(0.0, 3.0)), + Row.of(Vectors.dense(2.1, 0.0)), + Row.of(Vectors.dense(4.1, 5.1)), + Row.of(Vectors.dense(6.1, 8.1)), + Row.of(Vectors.dense(200, 400)))); + private static final List<Row> predictData = + new ArrayList<>(Collections.singletonList(Row.of(Vectors.dense(150.0, 90.0)))); + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(4); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + trainDataTable = tEnv.fromDataStream(env.fromCollection(trainData)).as("features"); + predictDataTable = tEnv.fromDataStream(env.fromCollection(predictData)).as("features"); + } + + private static void verifyPredictionResult(Table output, String outputCol, DenseVector expected) + throws Exception { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment(); + DataStream<DenseVector> stream = + tEnv.toDataStream(output) + .map( + (MapFunction<Row, DenseVector>) + row -> (DenseVector) row.getField(outputCol)); + List<DenseVector> result = IteratorUtils.toList(stream.executeAndCollect()); + assertEquals(1, result.size()); + assertEquals(expected, result.get(0)); + } + + @Test + public void testParam() { + MinMaxScaler minMaxScaler = new MinMaxScaler(); + assertEquals("features", minMaxScaler.getFeaturesCol()); + assertEquals("prediction", minMaxScaler.getPredictionCol()); + assertEquals(1.0, minMaxScaler.getMax(), 0.0001); + assertEquals(0.0, minMaxScaler.getMin(), 0.0001); + minMaxScaler + .setFeaturesCol("test_features") + .setPredictionCol("test_output") + .setMax(4.0) Review comment: Super nits: Given that we already use min before max at line 111, any chance we can consistently use `*min*` before `*max*` for a bit of extra readability and consistency? There are a few other places where we use max before min. It would be good to update them for consistency. Feel free to leave it as is and I can submit a followup PR to make this change. ########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/feature/MinMaxScalerTest.java ########## @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.flink.ml.feature; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.feature.minmaxscaler.MinMaxScaler; +import org.apache.flink.ml.feature.minmaxscaler.MinMaxScalerModel; +import org.apache.flink.ml.feature.minmaxscaler.MinMaxScalerModelData; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.ml.util.StageTestUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.junit.Assert.assertEquals; + +/** Tests {@link MinMaxScaler} and {@link MinMaxScalerModel}. */ +public class MinMaxScalerTest { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + private Table trainDataTable; + private Table predictDataTable; + private static final List<Row> trainData = + new ArrayList<>( + Arrays.asList( + Row.of(Vectors.dense(0.0, 3.0)), + Row.of(Vectors.dense(2.1, 0.0)), + Row.of(Vectors.dense(4.1, 5.1)), + Row.of(Vectors.dense(6.1, 8.1)), + Row.of(Vectors.dense(200, 400)))); + private static final List<Row> predictData = + new ArrayList<>(Collections.singletonList(Row.of(Vectors.dense(150.0, 90.0)))); + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(4); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + trainDataTable = tEnv.fromDataStream(env.fromCollection(trainData)).as("features"); + predictDataTable = tEnv.fromDataStream(env.fromCollection(predictData)).as("features"); + } + + private static void verifyPredictionResult(Table output, String outputCol, DenseVector expected) + throws Exception { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment(); + DataStream<DenseVector> stream = + tEnv.toDataStream(output) + .map( + (MapFunction<Row, DenseVector>) + row -> (DenseVector) row.getField(outputCol)); + List<DenseVector> result = IteratorUtils.toList(stream.executeAndCollect()); + assertEquals(1, result.size()); + assertEquals(expected, result.get(0)); + } + + @Test + public void testParam() { + MinMaxScaler minMaxScaler = new MinMaxScaler(); + assertEquals("features", minMaxScaler.getFeaturesCol()); + assertEquals("prediction", minMaxScaler.getPredictionCol()); + assertEquals(1.0, minMaxScaler.getMax(), 0.0001); + assertEquals(0.0, minMaxScaler.getMin(), 0.0001); + 
minMaxScaler + .setFeaturesCol("test_features") + .setPredictionCol("test_output") + .setMax(4.0) + .setMin(1.0); + assertEquals("test_features", minMaxScaler.getFeaturesCol()); + assertEquals(1.0, minMaxScaler.getMin(), 0.0001); + assertEquals(4.0, minMaxScaler.getMax(), 0.0001); + assertEquals("test_output", minMaxScaler.getPredictionCol()); + } + + @Test + public void testFeaturePredictionParam() { + MinMaxScaler minMaxScaler = + new MinMaxScaler() + .setFeaturesCol("test_features") + .setPredictionCol("test_output") + .setMin(1.0) + .setMax(4.0); + + MinMaxScalerModel model = minMaxScaler.fit(trainDataTable.as("test_features")); + Table output = model.transform(predictDataTable.as("test_features"))[0]; + assertEquals( + Arrays.asList("test_features", "test_output"), + output.getResolvedSchema().getColumnNames()); + } + + @Test + public void testFewerSamplesThanParallel() throws Exception { + MinMaxScaler minMaxScaler = new MinMaxScaler(); + MinMaxScalerModel model = minMaxScaler.fit(predictDataTable); + Table result = model.transform(predictDataTable)[0]; + verifyPredictionResult(result, minMaxScaler.getPredictionCol(), Vectors.dense(0.5, 0.5)); + } + + @Test + public void testMaxValueEqualsMinValueButPredictValueNotEquals() throws Exception { + List<Row> userDefineData = + new ArrayList<>(Collections.singletonList(Row.of(Vectors.dense(30.0, 50.0)))); + Table userDefineDataTable = + tEnv.fromDataStream(env.fromCollection(userDefineData)).as("features"); + MinMaxScaler minMaxScaler = new MinMaxScaler().setMax(10.0).setMin(0.0); + MinMaxScalerModel model = minMaxScaler.fit(predictDataTable); + Table result = model.transform(userDefineDataTable)[0]; + verifyPredictionResult(result, minMaxScaler.getPredictionCol(), Vectors.dense(5.0, 5.0)); + } + + @Test + public void testFitAndPredict() throws Exception { + MinMaxScaler minMaxScaler = new MinMaxScaler(); + MinMaxScalerModel minMaxScalerModel = minMaxScaler.fit(trainDataTable); + Table output = minMaxScalerModel.transform(predictDataTable)[0]; + verifyPredictionResult(output, minMaxScaler.getPredictionCol(), Vectors.dense(0.75, 0.225)); + } + + @Test + public void testSaveLoadAndPredict() throws Exception { + MinMaxScaler minMaxScaler = new MinMaxScaler(); + MinMaxScaler loadedMinMaxScaler = + StageTestUtils.saveAndReload( + env, minMaxScaler, tempFolder.newFolder().getAbsolutePath()); + MinMaxScalerModel minMaxScalerModel = loadedMinMaxScaler.fit(trainDataTable); + minMaxScalerModel = + StageTestUtils.saveAndReload( + env, minMaxScalerModel, tempFolder.newFolder().getAbsolutePath()); + assertEquals( + Arrays.asList("minVector", "maxVector"), + minMaxScalerModel.getModelData()[0].getResolvedSchema().getColumnNames()); + Table output = minMaxScalerModel.transform(predictDataTable)[0]; + verifyPredictionResult(output, minMaxScaler.getPredictionCol(), Vectors.dense(0.75, 0.225)); + } + + @Test + public void testGetModelData() throws Exception { + MinMaxScaler minMaxScaler = new MinMaxScaler(); + MinMaxScalerModel minMaxScalerModel = minMaxScaler.fit(trainDataTable); + Table modelData = minMaxScalerModel.getModelData()[0]; + DataStream<Row> output = tEnv.toDataStream(modelData); + assertEquals("minVector", modelData.getResolvedSchema().getColumnNames().get(0)); + assertEquals("maxVector", modelData.getResolvedSchema().getColumnNames().get(1)); + List<Row> modelRows = IteratorUtils.toList(output.executeAndCollect()); + MinMaxScalerModelData data = Review comment: Hmm... 
it is redundant to instantiate a `MinMaxScalerModelData` with the expected min/max vectors and then invoke assertEquals on these two class member variables. Would it be simpler to do the following directly? ``` assertEquals(new DenseVector(new double[] {0.0, 0.0}), modelRows.get(0).getField(0)); assertEquals(new DenseVector(new double[] {200.0, 400.0}), modelRows.get(0).getField(1)); ``` ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModelData.java ########## @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.minmaxscaler; + +import org.apache.flink.api.common.serialization.Encoder; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.connector.file.src.reader.SimpleStreamFormat; +import org.apache.flink.core.fs.FSDataInputStream; +import org.apache.flink.core.memory.DataInputView; +import org.apache.flink.core.memory.DataInputViewStreamWrapper; +import org.apache.flink.core.memory.DataOutputView; +import org.apache.flink.core.memory.DataOutputViewStreamWrapper; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; + +import java.io.EOFException; +import java.io.IOException; +import java.io.OutputStream; + +/** + * Model data of {@link MinMaxScalerModel}. + * + * <p>This class also provides methods to convert model data from Table to a data stream, and + * classes to save/load model data. + */ +public class MinMaxScalerModelData { + public DenseVector minVector; + + public DenseVector maxVector; + + public MinMaxScalerModelData() {} + + public MinMaxScalerModelData(DenseVector minVector, DenseVector maxVector) { + this.minVector = minVector; + this.maxVector = maxVector; + } + + /** + * Converts the table model to a data stream. + * + * @param modelDataTable The table model data. + * @return The data stream model data. + */ + public static DataStream<MinMaxScalerModelData> getModelDataStream(Table modelDataTable) { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) modelDataTable).getTableEnvironment(); + return tEnv.toDataStream(modelDataTable) + .map( + x -> + new MinMaxScalerModelData( + (DenseVector) x.getField(0), (DenseVector) x.getField(1))); + } + + /** Encoder for {@link MinMaxScalerModelData}. 
*/ + public static class ModelDataEncoder implements Encoder<MinMaxScalerModelData> { + @Override + public void encode(MinMaxScalerModelData minMaxScalerModelData, OutputStream outputStream) Review comment: Super nits: all other classes (except `KnnModelData`) use `modelData` as the input parameter name of this method. Any chance we can do the same here for consistency and shorten the code a little bit? ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScalerModel.java ########## @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.minmaxscaler; + +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.table.runtime.typeutils.ExternalTypeInfo; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** + * A Model which do a minMax scaler operation using the model data computed by {@link MinMaxScaler}. + */ +public class MinMaxScalerModel + implements Model<MinMaxScalerModel>, MinMaxScalerParams<MinMaxScalerModel> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private Table modelDataTable; + + public MinMaxScalerModel() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public MinMaxScalerModel setModelData(Table... inputs) { + modelDataTable = inputs[0]; + return this; + } + + @Override + public Table[] getModelData() { + return new Table[] {modelDataTable}; + } + + @Override + @SuppressWarnings("unchecked") + public Table[] transform(Table...
inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<Row> data = tEnv.toDataStream(inputs[0]); + DataStream<MinMaxScalerModelData> minMaxScalerModel = + MinMaxScalerModelData.getModelDataStream(modelDataTable); + final String broadcastModelKey = "broadcastModelKey"; + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll( + inputTypeInfo.getFieldTypes(), + ExternalTypeInfo.of(DenseVector.class)), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getPredictionCol())); + DataStream<Row> output = + BroadcastUtils.withBroadcastStream( + Collections.singletonList(data), + Collections.singletonMap(broadcastModelKey, minMaxScalerModel), + inputList -> { + DataStream input = inputList.get(0); + return input.map( + new PredictOutputFunction( + broadcastModelKey, + getMax(), + getMin(), + getFeaturesCol()), + outputTypeInfo); + }); + return new Table[] {tEnv.fromDataStream(output)}; + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + ReadWriteUtils.saveModelData( + MinMaxScalerModelData.getModelDataStream(modelDataTable), + path, + new MinMaxScalerModelData.ModelDataEncoder()); + } + + /** + * Loads model data from path. + * + * @param env Stream execution environment. + * @param path Model path. + * @return MinMaxScalerModel model. + */ + public static MinMaxScalerModel load(StreamExecutionEnvironment env, String path) + throws IOException { + StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); + MinMaxScalerModel model = ReadWriteUtils.loadStageParam(path); + DataStream<MinMaxScalerModelData> modelData = + ReadWriteUtils.loadModelData( + env, path, new MinMaxScalerModelData.ModelDataDecoder()); + return model.setModelData(tEnv.fromDataStream(modelData)); + } + + /** This operator loads model data and predicts result. */ + private static class PredictOutputFunction extends RichMapFunction<Row, Row> { + private final String featureCol; + private final String broadcastKey; + private MinMaxScalerModelData minMaxScalerModelData; Review comment: Given that the code already stores `scaleVector` and `offsetVector`, do we still need to keep `minMaxScalerModelData` as a class member variable here? Same for `minVector`. ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/feature/minmaxscaler/MinMaxScaler.java ########## @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.feature.minmaxscaler; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.RichMapPartitionFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.iteration.operator.OperatorStateUtils; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.runtime.state.StateSnapshotContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; + +/** + * An Estimator which implements the MinMaxScaler algorithm. This algorithm rescales feature values + * to a common range [min, max] which defined by user. + * + * <blockquote> + * + * $$ Rescaled(value) = \frac{value - E_{min}}{E_{max} - E_{min}} * (max - min) + min $$ + * + * </blockquote> + * + * <p>For the case \(E_{max} == E_{min}\), \(Rescaled(value) = 0.5 * (max + min)\). + * + * <p>See https://en.wikipedia.org/wiki/Feature_scaling#Rescaling_(min-max_normalization). + */ +public class MinMaxScaler + implements Estimator<MinMaxScaler, MinMaxScalerModel>, MinMaxScalerParams<MinMaxScaler> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public MinMaxScaler() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public MinMaxScalerModel fit(Table... 
inputs) { + Preconditions.checkArgument(inputs.length == 1); + final String featureCol = getFeaturesCol(); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<DenseVector> features = + tEnv.toDataStream(inputs[0]) + .map( + (MapFunction<Row, DenseVector>) + value -> (DenseVector) value.getField(featureCol)); + DataStream<DenseVector> minMaxValues = + features.transform( + "reduceInEachPartition", + features.getType(), + new MinMaxReduceFunctionOperator()) + .transform( + "reduceInFinalPartition", + features.getType(), + new MinMaxReduceFunctionOperator()) + .setParallelism(1); + DataStream<MinMaxScalerModelData> modelData = + DataStreamUtils.mapPartition( + minMaxValues, + new RichMapPartitionFunction<DenseVector, MinMaxScalerModelData>() { + @Override + public void mapPartition( + Iterable<DenseVector> values, + Collector<MinMaxScalerModelData> out) { + Iterator<DenseVector> iter = values.iterator(); + DenseVector minVector = iter.next(); + DenseVector maxVector = iter.next(); + out.collect(new MinMaxScalerModelData(minVector, maxVector)); + } + }); + + MinMaxScalerModel model = + new MinMaxScalerModel().setModelData(tEnv.fromDataStream(modelData)); + ReadWriteUtils.updateExistingParams(model, getParamMap()); + return model; + } + + /** + * A stream operator to compute the min and max values in each partition of the input bounded + * data stream. + */ + private static class MinMaxReduceFunctionOperator extends AbstractStreamOperator<DenseVector> + implements OneInputStreamOperator<DenseVector, DenseVector>, BoundedOneInput { + private ListState<DenseVector> minState; + private ListState<DenseVector> maxState; + + private DenseVector minVector; + private DenseVector maxVector; + + @Override + public void endInput() { + if (minVector != null) { + output.collect(new StreamRecord<>(minVector)); + output.collect(new StreamRecord<>(maxVector)); + } + } + + @Override + public void processElement(StreamRecord<DenseVector> streamRecord) { + DenseVector currentValue = streamRecord.getValue(); + if (minVector == null) { + int vecSize = currentValue.size(); + minVector = new DenseVector(vecSize); + maxVector = new DenseVector(vecSize); + System.arraycopy(currentValue.values, 0, minVector.values, 0, vecSize); + System.arraycopy(currentValue.values, 0, maxVector.values, 0, vecSize); + } else { + Preconditions.checkArgument( + currentValue.size() == maxVector.size(), + "CurrentValue should has same size with maxVector."); + for (int i = 0; i < currentValue.size(); ++i) { + minVector.values[i] = Math.min(minVector.values[i], currentValue.values[i]); + maxVector.values[i] = Math.max(maxVector.values[i], currentValue.values[i]); + } + } + } + + @Override + @SuppressWarnings("unchecked") + public void initializeState(StateInitializationContext context) throws Exception { + super.initializeState(context); + minState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "minState", + getOperatorConfig() + .getTypeSerializerIn( + 0, getClass().getClassLoader()))); + maxState = + context.getOperatorStateStore() + .getListState( + new ListStateDescriptor<>( + "maxState", + getOperatorConfig() + .getTypeSerializerIn( + 0, getClass().getClassLoader()))); + + OperatorStateUtils.getUniqueElement(minState, "minState").ifPresent(x -> minVector = x); + OperatorStateUtils.getUniqueElement(maxState, "maxState").ifPresent(x -> maxVector = x); + } + + @Override + @SuppressWarnings("unchecked") + public void 
snapshotState(StateSnapshotContext context) throws Exception { + super.snapshotState(context); + minState.clear(); + maxState.clear(); + if (minVector != null) { + minState.add(minVector); + } + if (maxVector != null) { Review comment: nits: instead of using two `if` statements, would it be simpler to do the following? ``` if (minVector != null) { minState.add(minVector); maxState.add(maxVector); } ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
